Merge pull request #25335 from lgeiger:deprecated-bazel-option
PiperOrigin-RevId: 232765856
diff --git a/.bazelrc b/.bazelrc
index c70c571..17285af 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -90,6 +90,12 @@
build:dynamic_kernels --define=dynamic_loaded_kernels=true
build:dynamic_kernels --copt=-DAUTOLOAD_DYNAMIC_KERNELS
+# Build TF with C++ 17 features.
+build:c++17 --cxxopt=-std=c++1z
+build:c++17 --cxxopt=-stdlib=libc++
+build:c++1z --cxxopt=-std=c++1z
+build:c++1z --cxxopt=-stdlib=libc++
+
# Default paths for TF_SYSTEM_LIBS
build --define=PREFIX=/usr
build --define=LIBDIR=$(PREFIX)/lib
diff --git a/README.md b/README.md
index 4e37b23..96a8ecf 100644
--- a/README.md
+++ b/README.md
@@ -24,7 +24,8 @@
networks research. The system is general enough to be applicable in a wide
variety of other domains, as well.
-TensorFlow provides stable Python API and C APIs as well as without API backwards compatibility guarantee like C++, Go, Java, JavaScript and Swift.
+TensorFlow provides stable Python and C APIs as well as non-guaranteed backwards
+compatible API's for C++, Go, Java, JavaScript and Swift.
Keep up to date with release announcements and security updates by
subscribing to
diff --git a/WORKSPACE b/WORKSPACE
index 957b8d8..9f07b9f 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -29,7 +29,7 @@
bazel_toolchains_repositories()
load(
- "@io_bazel_rules_docker//container:container.bzl",
+ "@io_bazel_rules_docker//repositories:repositories.bzl",
container_repositories = "repositories",
)
@@ -43,29 +43,17 @@
# Apple and Swift rules.
http_archive(
name = "build_bazel_rules_apple",
- sha256 = "4fe4ee824200b48821730f89ff260984332dc3551db587c24691235d1d96a8a7",
- strip_prefix = "rules_apple-0.10.0",
- urls = ["https://github.com/bazelbuild/rules_apple/archive/0.10.0.tar.gz"],
-)
-http_archive(
- name = "build_bazel_rules_swift",
- sha256 = "6544ff5615febec0342de1127144d2f3e43ea80fb7f9b1ade65e6a184e39e618",
- strip_prefix = "rules_swift-0.5.0",
- urls = ["https://github.com/bazelbuild/rules_swift/archive/0.5.0.tar.gz"],
-)
-http_archive(
- name = "bazel_skylib",
- sha256 = "eb5c57e4c12e68c0c20bc774bfbc60a568e800d025557bc4ea022c6479acc867",
- strip_prefix = "bazel-skylib-0.6.0",
- urls = ["https://github.com/bazelbuild/bazel-skylib/archive/0.6.0.tar.gz"],
+ sha256 = "73b4980a318d203d3307f850e27e66ec5cc8d223147a3475a6f11597eb6438a5",
+ strip_prefix = "rules_apple-0.13.0",
+ urls = ["https://github.com/bazelbuild/rules_apple/archive/0.13.0.tar.gz"],
)
http_file(
name = "xctestrunner",
executable = 1,
- urls = ["https://github.com/google/xctestrunner/releases/download/0.2.5/ios_test_runner.par"],
+ urls = ["https://github.com/google/xctestrunner/releases/download/0.2.6/ios_test_runner.par"],
)
load("@build_bazel_rules_apple//apple:repositories.bzl", "apple_rules_dependencies")
-apple_rules_dependencies(ignore_version_differences = True)
+apple_rules_dependencies()
load("@build_bazel_rules_swift//swift:repositories.bzl", "swift_rules_dependencies")
swift_rules_dependencies()
@@ -134,4 +122,3 @@
"http://download.tensorflow.org/models/speech_commands_v0.01.zip",
],
)
-
diff --git a/configure.py b/configure.py
index e626082..14fca1f 100644
--- a/configure.py
+++ b/configure.py
@@ -256,6 +256,7 @@
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
+
def cleanup_makefile():
"""Delete any leftover BUILD files from the Makefile build.
@@ -785,8 +786,7 @@
environ_cp,
var_name='GCC_HOST_COMPILER_PATH',
var_default=default_gcc_host_compiler_path,
- ask_for_var=
- 'Please specify which gcc should be used by nvcc as the host compiler.',
+ ask_for_var='Please specify which gcc should be used by nvcc as the host compiler.',
check_success=os.path.exists,
error_msg='Invalid gcc path. %s cannot be found.',
)
@@ -1237,6 +1237,7 @@
environ_cp['TF_NCCL_VERSION'] = tf_nccl_version
write_action_env_to_bazelrc('TF_NCCL_VERSION', tf_nccl_version)
+
def get_native_cuda_compute_capabilities(environ_cp):
"""Get native cuda compute capabilities.
@@ -1552,7 +1553,7 @@
# environment variables.
environ_cp = dict(os.environ)
- check_bazel_version('0.19.0', '0.21.0')
+ check_bazel_version('0.19.0', '0.22.0')
reset_tf_configure_bazelrc()
@@ -1683,8 +1684,9 @@
config_info_line('gdr', 'Build with GDR support.')
config_info_line('verbs', 'Build with libverbs support.')
config_info_line('ngraph', 'Build with Intel nGraph support.')
- config_info_line('dynamic_kernels',
- '(Experimental) Build kernels into separate shared objects.')
+ config_info_line(
+ 'dynamic_kernels',
+ '(Experimental) Build kernels into separate shared objects.')
print('Preconfigured Bazel build configs to DISABLE default on features:')
config_info_line('noaws', 'Disable AWS S3 filesystem support.')
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 5a8b97e..0b63ee4 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -95,6 +95,12 @@
)
config_setting(
+ name = "emscripten",
+ values = {"crosstool_top": "//external:android/emscripten"},
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
name = "raspberry_pi_armeabi",
values = {
"crosstool_top": "@local_config_arm_compiler//:toolchain",
@@ -392,17 +398,7 @@
package_group(
name = "internal",
- packages = [
- "-//third_party/tensorflow/python/estimator",
- "//learning/deepmind/...",
- "//learning/meta_rank/...",
- "//platforms/performance/autograppler/...",
- "//tensorflow/...",
- "//tensorflow_estimator/contrib/...",
- "//tensorflow_fold/llgtm/...",
- "//tensorflow_text/...",
- "//third_party/py/tensor2tensor/...",
- ],
+ packages = ["//tensorflow/..."],
)
load(
diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD
index 6e50a09..ef7863d 100644
--- a/tensorflow/c/BUILD
+++ b/tensorflow/c/BUILD
@@ -67,6 +67,23 @@
tf_cuda_library(
name = "c_api",
+ hdrs = ["c_api.h"],
+ copts = tf_copts(),
+ visibility = ["//visibility:public"],
+ deps = [
+ ":c_api_no_xla",
+ ":c_api_internal",
+ ] + select({
+ "//tensorflow:with_xla_support": [
+ "//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/jit",
+ ],
+ "//conditions:default": [],
+ }),
+)
+
+tf_cuda_library(
+ name = "c_api_no_xla",
srcs = [
"c_api.cc",
"c_api_function.cc",
@@ -75,14 +92,12 @@
"c_api.h",
],
copts = tf_copts(),
- visibility = ["//visibility:public"],
- deps = select({
+ visibility = ["//tensorflow/c:__subpackages__"],
+ deps = [":c_api_internal"] + select({
"//tensorflow:android": [
- ":c_api_internal",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
- ":c_api_internal",
"//tensorflow/cc/saved_model:loader_lite",
"//tensorflow/cc:gradients",
"//tensorflow/cc:ops",
@@ -97,13 +112,8 @@
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core/distributed_runtime:server_lib",
+ "//tensorflow/core/kernels:logging_ops",
],
- }) + select({
- "//tensorflow:with_xla_support": [
- "//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/jit",
- ],
- "//conditions:default": [],
}),
)
@@ -156,8 +166,8 @@
hdrs = ["tf_status_helper.h"],
visibility = ["//visibility:public"],
deps = [
- ":c_api",
":c_api_internal",
+ ":c_api_no_xla",
"//tensorflow/core:lib",
],
)
@@ -213,13 +223,13 @@
visibility = ["//visibility:public"],
deps = select({
"//tensorflow:android": [
- ":c_api",
+ ":c_api_no_xla",
":c_api_internal",
":tf_status_helper",
"//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
- ":c_api",
+ ":c_api_no_xla",
":c_api_internal",
":tf_status_helper",
"//tensorflow/core:framework",
@@ -300,9 +310,12 @@
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:proto_text",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:spectral_ops_op_lib",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:array",
@@ -343,6 +356,7 @@
srcs = ["c_api_function_test.cc"],
deps = [
":c_api",
+ ":c_api_internal",
":c_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index 94d9f4a..ef22b67 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -27,6 +27,7 @@
#include "tensorflow/cc/ops/while_loop.h"
#include "tensorflow/cc/saved_model/loader.h"
#include "tensorflow/core/framework/op_gen_lib.h"
+#include "tensorflow/core/kernels/logging_ops.h"
#endif
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -1310,6 +1311,13 @@
reinterpret_cast<const DataType*>(values), num_values));
}
+void TF_SetAttrPlaceholder(TF_OperationDescription* desc, const char* attr_name,
+ const char* placeholder) {
+ tensorflow::AttrValue attr_value;
+ attr_value.set_placeholder(placeholder);
+ desc->node_builder.Attr(attr_name, attr_value);
+}
+
void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name,
const char* value, size_t length) {
tensorflow::NameAttrList func_name;
@@ -2954,4 +2962,11 @@
delete server;
#endif
}
+
+void TF_RegisterLogListener(void (*listener)(const char*)) {
+#ifndef __ANDROID__
+ tensorflow::logging::RegisterListener(listener);
+#endif
+}
+
} // end extern "C"
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 8031928..88b8b49 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -549,6 +549,10 @@
const char* attr_name,
const TF_DataType* values,
int num_values);
+TF_CAPI_EXPORT extern void TF_SetAttrPlaceholder(TF_OperationDescription* desc,
+ const char* attr_name,
+ const char* placeholder);
+
// Set a 'func' attribute to the specified name.
// `value` must point to a string of length `length` bytes.
TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc,
@@ -1743,6 +1747,14 @@
// it will be stopped and joined.
TF_CAPI_EXPORT extern void TF_DeleteServer(TF_Server* server);
+// Register a listener method that processes printed messages.
+//
+// If any listeners are registered, the print operator will call all listeners
+// with the printed messages and immediately return without writing to the
+// logs.
+TF_CAPI_EXPORT extern void TF_RegisterLogListener(
+ void (*listener)(const char*));
+
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc
index 28b9f8d..1477d64 100644
--- a/tensorflow/c/c_api_function.cc
+++ b/tensorflow/c/c_api_function.cc
@@ -162,6 +162,11 @@
const std::vector<const Node*>& body_nodes,
const std::unordered_map<string, string>& tensor_renaming,
FunctionDef* fdef) {
+ std::unordered_set<string> func_attr_names;
+ for (const auto& func_attr : fdef->signature().attr()) {
+ func_attr_names.insert(func_attr.name());
+ }
+
std::vector<const Edge*> in_edges;
std::vector<const Edge*> control_edges;
for (const Node* node : body_nodes) {
@@ -243,6 +248,39 @@
if (node->op_def().is_stateful()) {
fdef->mutable_signature()->set_is_stateful(true);
}
+
+ // If this node has any attributes with placeholder value, add the
+ // attribute to FunctionDef signature.
+ for (const auto& iter : node->attrs()) {
+ if (iter.second.placeholder().empty()) {
+ continue;
+ }
+
+ // If we already added the attribute, skip it.
+ string func_attr_name = iter.second.placeholder();
+ if (func_attr_names.find(func_attr_name) != func_attr_names.end()) {
+ continue;
+ }
+
+ // This node's attribute is a placeholder value, so it does not have type
+ // information. We check node's OpDef for attribute type.
+ string node_attr_name = iter.first;
+ const OpDef_AttrDef* node_attr_def = nullptr;
+ for (const auto& node_attr : node->op_def().attr()) {
+ if (node_attr.name() == node_attr_name) {
+ node_attr_def = &node_attr;
+ }
+ }
+ if (!node_attr_def) {
+ return errors::Internal("Cannot find attr ", node_attr_name,
+ " in OpDef ", node->op_def().DebugString());
+ }
+ OpDef_AttrDef* attr_def = fdef->mutable_signature()->add_attr();
+ attr_def->set_name(func_attr_name);
+ attr_def->set_type(node_attr_def->type());
+
+ func_attr_names.insert(func_attr_name);
+ }
}
return Status::OK();
}
diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc
index 73fe737..946f8c4 100644
--- a/tensorflow/c/c_api_function_test.cc
+++ b/tensorflow/c/c_api_function_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/c/c_api.h"
+#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
@@ -1230,6 +1231,53 @@
ASSERT_NE(*func, nullptr);
}
+REGISTER_OP("CustomOp")
+ .Output("output: float32")
+ .Attr("index: int")
+ .SetShapeFn(tensorflow::shape_inference::UnknownShape);
+
+void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
+ const char* name, const char* placeholder,
+ TF_Operation** op) {
+ TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
+ TF_SetAttrPlaceholder(desc, "index", placeholder);
+ *op = TF_FinishOperation(desc, s);
+ ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
+ ASSERT_NE(*op, nullptr);
+}
+
+TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
+ std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
+ TF_NewGraph(), TF_DeleteGraph);
+ std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
+ TF_DeleteStatus);
+
+ TF_Operation *node1, *node2, *node3;
+ NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node1", "v1",
+ &node1);
+ NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node2", "v1",
+ &node2);
+ NodeWithPlaceholderAttrHelper(func_graph.get(), s.get(), "node3", "v2",
+ &node3);
+
+ TF_Output inputs[] = {};
+ TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
+ func_ = TF_GraphToFunction(
+ func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
+ /*opers=*/nullptr, 0, inputs, 3, outputs,
+ /*output_names=*/nullptr,
+ /*opts=*/nullptr, /*description=*/nullptr, s.get());
+ ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
+ ASSERT_NE(func_, nullptr);
+
+ // Verify that FunctionDef has 2 attributes, "v1" and "v2".
+ ASSERT_EQ(func_->fdef.signature().attr().size(), 2);
+ EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1");
+ EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int");
+ EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2");
+ EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
+}
+
TEST_F(CApiFunctionTest, SetGradientAndRun) {
// Define the function and its grad
DefineFunction(func_name_, &func_);
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 04dfefa..257be63 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -70,7 +70,7 @@
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
- "//tensorflow/core/profiler/lib:eager_profiler",
+ "//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
)
@@ -110,7 +110,7 @@
"//tensorflow/core/distributed_runtime/rpc:grpc_worker_service",
"//tensorflow/core/distributed_runtime/rpc:rpc_rendezvous_mgr",
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
- "//tensorflow/core/profiler/lib:eager_profiler",
+ "//tensorflow/core/profiler/lib:profiler_session",
],
)
diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h
index 3b9e681..b70c0f1 100644
--- a/tensorflow/c/eager/c_api_internal.h
+++ b/tensorflow/c/eager/c_api_internal.h
@@ -52,7 +52,7 @@
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
-#include "tensorflow/core/profiler/lib/eager_profiler.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
@@ -109,9 +109,9 @@
struct TFE_Profiler {
TFE_Profiler(TFE_Context* ctx)
- : profiler(tensorflow::EagerProfiler::Create(&ctx->context)) {}
+ : profiler(tensorflow::ProfilerSession::Create(&ctx->context)) {}
- std::unique_ptr<tensorflow::EagerProfiler> profiler;
+ std::unique_ptr<tensorflow::ProfilerSession> profiler;
};
namespace tensorflow {
diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc
index 10f7abf..66260fc 100644
--- a/tensorflow/cc/saved_model/loader.cc
+++ b/tensorflow/cc/saved_model/loader.cc
@@ -26,7 +26,6 @@
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf_internal.h"
-#include "tensorflow/core/protobuf/saved_model.pb.h"
#include "tensorflow/core/protobuf/saver.pb.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 9fc223b..0e46a9f 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -108,10 +108,13 @@
computation.Snapshot());
// Serialize the HloSnapshot deterministically so that all the outputs of a
// tf_library genrule are deterministic.
- string proto;
- TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto));
+ const size_t size = module->ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ TF_RET_CHECK(
+ SerializeToBufferDeterministic(*module, serialized.get(), size));
TF_RETURN_IF_ERROR(
- WriteStringToFile(Env::Default(), flags.out_session_module, proto));
+ WriteStringToFile(Env::Default(), flags.out_session_module,
+ absl::string_view(serialized.get(), size)));
}
xla::cpu::CpuAotCompilationOptions aot_opts(
flags.target_triple, flags.target_cpu, flags.target_features,
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 55e2e6d..3cae081 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -179,14 +179,18 @@
"//tensorflow/core:control_flow_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:resource_variable_ops_op_lib",
+ "//tensorflow/core:sendrecv_ops_op_lib",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 9f40426..285b1ef 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -115,6 +115,13 @@
return;
}
+ if (ctrl_edges.size() == 1 && ctrl_edges.front()->dst()->IsSink()) {
+ // Avoid creating a Merge node if we can just add an edge to _SINK
+ // instead.
+ s.graph()->AddControlEdge(new_node, s.graph()->sink_node());
+ return;
+ }
+
// We can't merge control edges directly so we instead first "convert" them to
// normal values that can be merged, merge the values and then "convert" the
// merged value back into control.
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index 390ffa6..c14c746 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -68,6 +68,8 @@
}
}
+ FixupSourceAndSinkEdges(graph.get());
+
GraphOptimizationPassOptions opt_options;
opt_options.graph = &graph;
BuildXlaOpsPass pass(/*enable_lazy_compilation=*/true);
@@ -223,5 +225,23 @@
ASSERT_NE(write_op_new, nullptr);
EXPECT_THAT(write_op_new, assign_var);
}
+
+TEST_F(BuildXlaOpsTest, NoExtraMergeForEdgeToSink) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", &call));
+
+ std::unique_ptr<Graph> graph;
+ TF_ASSERT_OK(BuildXlaOps(root, &graph));
+
+ Node* sink_node = graph->sink_node();
+ EXPECT_THAT(sink_node, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")),
+ NodeWith(Op("cluster_0")),
+ NodeWith(Op("NoOp")))));
+}
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 0ef0d3d..4397eea 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -113,7 +113,11 @@
enum class Kind { kAnd, kOr, kNot, kAndRecurrence, kSymbol };
virtual string ToString() const = 0;
- int64 hash() const { return hash_; }
+
+ // An ID assigned to the Predicate at construction time. Conceptually like a
+ // pointer, except that it is stable across runs.
+ int64 id() const { return id_; }
+
virtual absl::Span<Predicate* const> GetOperands() const = 0;
virtual Kind kind() const = 0;
@@ -126,29 +130,19 @@
static void Visit(Predicate* p, const FunctionTy& func);
protected:
- explicit Predicate(int64 hash) : hash_(hash) {}
+ explicit Predicate(int64 id) : id_(id) {}
private:
- const int64 hash_;
+ const int64 id_;
TF_DISALLOW_COPY_AND_ASSIGN(Predicate);
};
-int64 HashPredicateSequence(Predicate::Kind kind,
- absl::Span<Predicate* const> preds) {
- int64 hash = ::tensorflow::hash<Predicate::Kind>()(kind);
- for (Predicate* pred : preds) {
- hash = Hash64Combine(hash, pred->hash());
- }
- return hash;
-}
-
// Represents a logical conjunction of a set of predicates.
class AndPredicate : public Predicate {
public:
- explicit AndPredicate(std::vector<Predicate*> operands)
- : Predicate(HashPredicateSequence(Kind::kAnd, operands)),
- operands_(std::move(operands)) {}
+ explicit AndPredicate(int64 id, std::vector<Predicate*> operands)
+ : Predicate(id), operands_(std::move(operands)) {}
string ToString() const override {
if (operands().empty()) {
@@ -177,9 +171,8 @@
// Represents a logical disjunction of a set of predicates.
class OrPredicate : public Predicate {
public:
- explicit OrPredicate(std::vector<Predicate*> operands)
- : Predicate(HashPredicateSequence(Kind::kOr, operands)),
- operands_(std::move(operands)) {}
+ explicit OrPredicate(int64 id, std::vector<Predicate*> operands)
+ : Predicate(id), operands_(std::move(operands)) {}
string ToString() const override {
if (operands().empty()) {
@@ -207,9 +200,8 @@
// Represents a logical negation of a set of predicates.
class NotPredicate : public Predicate {
public:
- explicit NotPredicate(Predicate* operand)
- : Predicate(HashPredicateSequence(Kind::kNot, {operand})),
- operands_({operand}) {}
+ explicit NotPredicate(int64 id, Predicate* operand)
+ : Predicate(id), operands_({operand}) {}
string ToString() const override {
return absl::StrCat("~", operand()->ToString());
@@ -246,11 +238,9 @@
// iterations).
class AndRecurrencePredicate : public Predicate {
public:
- explicit AndRecurrencePredicate(Predicate* start, Predicate* step,
+ explicit AndRecurrencePredicate(int64 id, Predicate* start, Predicate* step,
std::vector<string> frame)
- : Predicate(Hash(start, step, frame)),
- operands_({start, step}),
- frame_(std::move(frame)) {}
+ : Predicate(id), operands_({start, step}), frame_(std::move(frame)) {}
Predicate* start() const { return operands_[0]; }
Predicate* step() const { return operands_[1]; }
@@ -270,16 +260,6 @@
private:
std::array<Predicate*, 2> operands_;
std::vector<string> frame_;
-
- static int64 Hash(Predicate* start, Predicate* step,
- const std::vector<string>& frame) {
- uint64 frame_hash = 0;
- for (const string& sub_frame : frame) {
- frame_hash = Hash64Combine(Hash64(sub_frame), frame_hash);
- }
- return Hash64Combine(
- HashPredicateSequence(Kind::kAndRecurrence, {start, step}), frame_hash);
- }
};
// Represents an uninterpreted symbol in a logical predicate.
@@ -289,8 +269,8 @@
// symbols.
class SymbolPredicate : public Predicate {
public:
- explicit SymbolPredicate(TensorId tensor_id, bool must_be_true)
- : Predicate(Hash(tensor_id, must_be_true)),
+ explicit SymbolPredicate(int64 id, TensorId tensor_id, bool must_be_true)
+ : Predicate(id),
tensor_id_(std::move(tensor_id)),
must_be_true_(must_be_true) {}
@@ -313,13 +293,6 @@
private:
TensorId tensor_id_;
bool must_be_true_;
-
- static int64 Hash(const TensorId tensor_id, bool must_be_true) {
- return Hash64Combine(
- ::tensorflow::hash<bool>()(must_be_true),
- Hash64Combine(::tensorflow::hash<Predicate::Kind>()(Kind::kSymbol),
- TensorId::Hasher{}(tensor_id)));
- }
};
template <typename FunctionTy>
@@ -477,8 +450,11 @@
template <typename PredicateT, typename... Args>
std::unique_ptr<Predicate> Make(Args&&... args) {
+ // If we ever expose the Predicate class outside this .cc file then we may
+ // want to make this hard to misuse (by accidentally passing in an arbitrary
+ // integer to the Predicate constructor for instance).
return std::unique_ptr<PredicateT>(
- new PredicateT(std::forward<Args>(args)...));
+ new PredicateT(id_counter_++, std::forward<Args>(args)...));
}
Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
@@ -559,6 +535,7 @@
absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
HashSignatureForSymbol>
interned_symbol_instances_;
+ int64 id_counter_ = 0;
int stack_depth_ = 0;
};
@@ -566,7 +543,7 @@
std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
std::stable_sort(
simplified_ops.begin(), simplified_ops.end(),
- [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+ [](Predicate* a, Predicate* b) { return a->id() < b->id(); });
auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
if (it != interned_and_or_instances_.end()) {
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 16ee8f8..38a5118 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -521,7 +521,7 @@
EXPECT_EQ(predicate_map[ControlOutputFor(iv2)],
"{#true,&,*iv2/cond:0}<fr0>");
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
- "({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv0/cond:0}<fr0>)");
+ "({#true,&,*iv0/cond:0}<fr0> & {#true,&,*iv1/cond:0}<fr0>)");
EXPECT_EQ(predicate_map[ControlOutputFor(add1)],
"({#true,&,*iv1/cond:0}<fr0> & {#true,&,*iv2/cond:0}<fr0>)");
}
@@ -553,11 +553,11 @@
EXPECT_EQ(predicate_map[ControlOutputFor(iv.induction_var)],
"{#true,&,*iv0/cond:0}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv0)],
- "{#true,&,(*iv0/cond:0 & iv0/iv:0)}<loop>");
+ "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_iv1)],
- "{#true,&,(*iv0/cond:0 & iv0/iv:0)}<loop>");
+ "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
- "{#true,&,(*iv0/cond:0 & iv0/iv:0)}<loop>");
+ "{#true,&,(iv0/iv:0 & *iv0/cond:0)}<loop>");
}
}
@@ -643,22 +643,23 @@
EXPECT_EQ(predicate_map[ControlOutputFor(iv_outer.induction_var)],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(iv_inner.induction_var)],
- "{({#true,&,*iv_outer/cond:0}<outer_loop> & "
- "*iv_outer/cond:0),&,*iv_inner/cond:0}<inner_loop;outer_loop>");
+ "{(*iv_outer/cond:0 & "
+ "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
+ "cond:0}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv0)],
"{{#true,&,(iv_outer/iv:0 & "
- "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
- "iv_inner/iv:0)}<inner_loop;outer_loop>");
+ "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
+ "*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(dependent_inner_iv1)],
"{{#true,&,(iv_outer/iv:0 & "
- "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
- "iv_inner/iv:0)}<inner_loop;outer_loop>");
+ "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
+ "*iv_inner/cond:0)}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
"{{#true,&,(iv_outer/iv:0 & "
- "*iv_outer/cond:0)}<outer_loop>,&,(*iv_inner/cond:0 & "
- "iv_inner/iv:0)}<inner_loop;outer_loop>");
+ "*iv_outer/cond:0)}<outer_loop>,&,(iv_inner/iv:0 & "
+ "*iv_inner/cond:0)}<inner_loop;outer_loop>");
}
}
@@ -702,20 +703,21 @@
EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[0])],
"{#true,&,*iv_outer/cond:0}<outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[0])],
- "{({#true,&,*iv_outer/cond:0}<outer_loop> & "
- "*iv_outer/cond:0),&,*iv_inner/cond:0}<inner_loop;outer_loop>");
+ "{(*iv_outer/cond:0 & "
+ "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
+ "cond:0}<inner_loop;outer_loop>");
EXPECT_EQ(predicate_map[ControlOutputFor(outer_iv[1])],
"{#true,&,*iv_outer/cond_1:0}<outer_loop>");
- EXPECT_EQ(
- predicate_map[ControlOutputFor(inner_iv[1])],
- "{({#true,&,*iv_outer/cond_1:0}<outer_loop> & "
- "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0}<inner_loop;outer_loop>");
- EXPECT_EQ(
- predicate_map[ControlOutputFor(add0)],
- "({({#true,&,*iv_outer/cond:0}<outer_loop> & "
- "*iv_outer/cond:0),&,*iv_inner/cond:0}<inner_loop;outer_loop> & "
- "{({#true,&,*iv_outer/cond_1:0}<outer_loop> & "
- "*iv_outer/cond_1:0),&,*iv_inner/cond_1:0}<inner_loop;outer_loop>)");
+ EXPECT_EQ(predicate_map[ControlOutputFor(inner_iv[1])],
+ "{(*iv_outer/cond_1:0 & "
+ "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
+ "cond_1:0}<inner_loop;outer_loop>");
+ EXPECT_EQ(predicate_map[ControlOutputFor(add0)],
+ "({(*iv_outer/cond:0 & "
+ "{#true,&,*iv_outer/cond:0}<outer_loop>),&,*iv_inner/"
+ "cond:0}<inner_loop;outer_loop> & {(*iv_outer/cond_1:0 & "
+ "{#true,&,*iv_outer/cond_1:0}<outer_loop>),&,*iv_inner/"
+ "cond_1:0}<inner_loop;outer_loop>)");
}
}
diff --git a/tensorflow/compiler/jit/encapsulate_util_test.cc b/tensorflow/compiler/jit/encapsulate_util_test.cc
index 3bb979e..6d16612 100644
--- a/tensorflow/compiler/jit/encapsulate_util_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_util_test.cc
@@ -21,7 +21,6 @@
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
-#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index ec745cd..c9ae717 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -173,9 +173,10 @@
// Nondeterminism in serialization would not lead to incorrect results, but
// may cause spurious cache misses. DeterministicSerialization is a
// best-effort deterministic serialization.
- string serialized;
- TF_RET_CHECK(SerializeToStringDeterministic(gdef, &serialized));
- uint64 fingerprint = Fingerprint64(serialized);
+ const size_t size = gdef.ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ TF_RET_CHECK(SerializeToBufferDeterministic(gdef, serialized.get(), size));
+ uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
LOG(INFO) << "Subgraph fingerprint:" << fingerprint;
call_def->set_op(absl::StrCat(call_def->op(), "_", fingerprint));
return Status::OK();
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
index ebfffc3..5287fd1 100644
--- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass.cc
@@ -247,6 +247,7 @@
.NewSubScope(absl::StrCat(slice->name(), "/static_shaped_slice"));
Scope host_scope = main_scope.WithAssignedDevice(host_name);
+ // In the future we may want to be clever here and avoid the extra Cast ops.
SliceInputs slice_inputs_int64 =
MakeSliceIndexAndSizeInt64(host_scope, slice_inputs);
@@ -312,9 +313,9 @@
return Status::OK();
}
-// Return true if `n` is a slice we can rewrite to have a static shape
+// Return true if `n` is a slice we should rewrite to have a static shape
// (i.e. have the output shape only depend on the "size" input).
-xla::StatusOr<bool> IsRewritableSlice(Node* n) {
+xla::StatusOr<bool> ShouldRewriteSlice(Node* n) {
if (n->type_string() != "Slice") {
return false;
}
@@ -332,14 +333,20 @@
// If slice_size[i] < -1 for any i then executing the slice will throw an
// error, and we don't do anything here.
- return absl::c_all_of(slice_inputs->size_as_vector,
- [](int64 size_i) { return size_i >= -1; });
+ bool slice_size_has_error = absl::c_all_of(
+ slice_inputs->size_as_vector, [](int64 size_i) { return size_i >= -1; });
+ if (!slice_size_has_error) {
+ return false;
+ }
+
+ // No point in rewriting slices that have both size and begin as constants.
+ return !slice_inputs->begin.node()->IsConstant();
}
Status FindAndRewriteSlices(Graph* g, bool* changed) {
std::vector<Node*> slices_to_rewrite;
for (Node* n : g->nodes()) {
- TF_ASSIGN_OR_RETURN(bool is_rewritable, IsRewritableSlice(n));
+ TF_ASSIGN_OR_RETURN(bool is_rewritable, ShouldRewriteSlice(n));
if (is_rewritable) {
slices_to_rewrite.push_back(n);
}
diff --git a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
index 32e3021..2add2c1 100644
--- a/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
+++ b/tensorflow/compiler/jit/increase_dynamism_for_auto_jit_pass_test.cc
@@ -432,5 +432,26 @@
Name("dependency")))));
}
+TEST(SliceToDynamicSliceRewriteTest, DontRewriteSliceWithConstBegin) {
+ Scope root = Scope::NewRootScope()
+ .ExitOnError()
+ .WithAssignedDevice(kDeviceName)
+ .WithXlaCluster("cluster_0");
+
+ Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+ Output begin = ops::Const(root.WithOpName("begin"), {10, 10});
+ Output size = ops::Const(root.WithOpName("size"), {-1, 500});
+ Output slice = ops::Slice(root.WithOpName("slice"), input, begin, size);
+
+ std::unique_ptr<Graph> result;
+ TF_ASSERT_OK(IncreaseDynamismForAutoJit(root, &result));
+
+ Node* slice_node = testing::FindNodeByName(result.get(), "slice");
+ EXPECT_THAT(slice_node,
+ NodeWith(Op("Slice"), Inputs(Out(NodeWith(Op("Placeholder"))),
+ Out(NodeWith(Op("Const"))),
+ Out(NodeWith(Op("Const"))))));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 0583774..bab824c 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -25,6 +25,7 @@
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
"@com_google_absl//absl/container:flat_hash_map",
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index e2397f6..c67b4f1 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -479,7 +479,7 @@
return sync_on_completion_;
}
-Status XlaDevice::CurrentStatus() {
+Status XlaDevice::RefreshStatus() {
std::shared_ptr<se::Stream> stream;
{
mutex_lock lock(mu_);
@@ -488,7 +488,8 @@
if (!stream) {
return Status::OK();
}
- return stream->ok() ? Status::OK() : errors::Internal("XlaDevice is not OK.");
+ // Stream status is XlaDevice status, no extra operations needed.
+ return stream->RefreshStatus();
}
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index e35a1c7..5fe1290 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -169,10 +169,9 @@
// Instructs this XlaDevice to return 'sync_on_completion' for
// AllowsSyncOnCompletion().
void SetAllowsSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
-
bool AllowsSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
- Status CurrentStatus() override LOCKS_EXCLUDED(mu_);
+ Status RefreshStatus() override LOCKS_EXCLUDED(mu_);
private:
xla::LocalClient* client() const;
diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h
index 927f983..f201f62 100644
--- a/tensorflow/compiler/jit/xla_device_ops.h
+++ b/tensorflow/compiler/jit/xla_device_ops.h
@@ -241,6 +241,8 @@
data::AnonymousIteratorHandleOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE), \
data::IteratorGetNextOp); \
+ REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE), \
+ data::IteratorGetNextAsOptionalOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE), \
data::IteratorGetNextSyncOp); \
REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle") \
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 139c927..9b6ca40 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -72,7 +72,7 @@
tf_xla_py_test(
name = "adadelta_test",
- size = "large",
+ size = "medium",
srcs = ["adadelta_test.py"],
deps = [
":xla_test",
diff --git a/tensorflow/compiler/tests/adadelta_test.py b/tensorflow/compiler/tests/adadelta_test.py
index b7b7fda..6cf16cc 100644
--- a/tensorflow/compiler/tests/adadelta_test.py
+++ b/tensorflow/compiler/tests/adadelta_test.py
@@ -32,10 +32,18 @@
def testBasic(self):
num_updates = 4 # number of ADADELTA steps to perform
+ if "CPU" in self.device:
+ # To avoid timeout on CPU.
+ all_grad = [0.2, 0.01]
+ all_lr = [1.0, 0.1]
+ else:
+ all_grad = [0.2, 0.1, 0.01]
+ all_lr = [1.0, 0.5, 0.1]
+
for dtype in self.float_types:
with self.cached_session(), self.test_scope():
- for grad in [0.2, 0.1, 0.01]:
- for lr in [1.0, 0.5, 0.1]:
+ for grad in all_grad:
+ for lr in all_lr:
var0_init = [1.0, 2.0]
var1_init = [3.0, 4.0]
var0 = resource_variable_ops.ResourceVariable(
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
index c9fce39..632eccb 100644
--- a/tensorflow/compiler/tests/eager_test.py
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -34,6 +34,7 @@
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
@@ -607,6 +608,21 @@
self.assertEqual(11.0, plus_one.numpy())
self.assertEqual(9.0, minus_one.numpy())
+ def testScanInDefun(self):
+ with self.test_scope():
+ elems = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], name='data')
+ v = constant_op.constant(2.0, name='v')
+
+ @def_function.function
+ def f(y):
+ # pylint: disable=unnecessary-lambda
+ return functional_ops.scan(
+ lambda a, x: math_ops.multiply(a, x), y, initializer=v)
+ # pylint: enable=unnecessary-lambda
+
+ r = f(elems)
+ self.assertAllEqual([2., 4., 12., 48., 240., 1440.], self.evaluate(r))
+
class ExcessivePaddingTest(xla_test.XLATestCase):
"""Test that eager execution works with TPU flattened tensors.
diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py
index 374942a..56a8e1b 100644
--- a/tensorflow/compiler/tests/fused_batchnorm_test.py
+++ b/tensorflow/compiler/tests/fused_batchnorm_test.py
@@ -191,6 +191,20 @@
mean_val = np.random.random_sample(scale_shape).astype(np.float32)
var_val = np.random.random_sample(scale_shape).astype(np.float32)
epsilon = 0.001
+
+ # The TensorFlow FusedBatchNormGrad training operation takes two inputs with
+ # implementation defined values. In theory the only correct value these
+ # inputs are the corresponding reserve_space_{1|2} outputs from the
+ # FusedBatchNorm training operation. However, in practice, we rely on the
+ # first one being mean on {C|G}PU, and the second one being variance on CPU
+ # and inverse(sqrt(variance + epsilon)) on GPU (we test this assumption
+ # separately).
+ reserve_space_1_val = mean_val
+ if self.device == "XLA_GPU":
+ reserve_space_2_val = np.reciprocal(np.sqrt(var_val + epsilon))
+ else:
+ reserve_space_2_val = var_val
+
data_format_src = "NHWC"
grad_x_ref, grad_scale_ref, grad_offset_ref = self._reference_grad(
x_val, grad_val, scale_val, mean_val, var_val, epsilon, data_format_src)
@@ -207,18 +221,26 @@
np.float32, shape=x_val_converted.shape, name="grad")
x = array_ops.placeholder(
np.float32, shape=x_val_converted.shape, name="x")
- mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean")
- var = array_ops.placeholder(np.float32, shape=scale_shape, name="var")
+ reserve_space_1 = array_ops.placeholder(
+ np.float32, shape=scale_shape, name="reserve_space_1")
+ reserve_space_2 = array_ops.placeholder(
+ np.float32, shape=scale_shape, name="reserve_space_2")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad(
- grad, x, scale, mean, var, data_format=data_format, is_training=True)
+ grad,
+ x,
+ scale,
+ reserve_space_1,
+ reserve_space_2,
+ data_format=data_format,
+ is_training=True)
grad_x_val, grad_scale_val, grad_offset_val = sess.run(
[grad_x, grad_scale, grad_offset], {
grad: grad_val_converted,
x: x_val_converted,
- mean: mean_val,
- var: var_val,
+ reserve_space_1: reserve_space_1_val,
+ reserve_space_2: reserve_space_2_val,
scale: scale_val
})
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 083e2e5..978ed66 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -392,6 +392,11 @@
[[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
self._assertOpOutputMatchesExpected(
+ math_ops.sign,
+ np.array([[-2.0, -1.0, -0.0, +0.0, 1.0, 2.0]], dtype=dtype),
+ expected=np.array([[-1.0, -1.0, -0.0, +0.0, 1.0, 1.0]], dtype=dtype))
+
+ self._assertOpOutputMatchesExpected(
math_ops.is_finite,
np.array(
[[42, float("inf"), -123], [float("nan"), 0, -0.0]], dtype=dtype),
@@ -743,6 +748,10 @@
np.array(
[[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
+ self._assertOpOutputMatchesExpected(
+ math_ops.sign,
+ np.array([[np.nan]], dtype=dtype),
+ expected=np.array([[0.0]], dtype=dtype))
def testLogicalOps(self):
self._assertOpOutputMatchesExpected(
diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD
index 35d577e..00d3c8c 100644
--- a/tensorflow/compiler/tf2tensorrt/BUILD
+++ b/tensorflow/compiler/tf2tensorrt/BUILD
@@ -323,6 +323,7 @@
":trt_plugins",
"@com_google_googletest//:gtest",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:ops",
"//tensorflow/cc:scope",
@@ -332,6 +333,7 @@
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensor_testutil",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
index 8b902dd..f3db425 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc
@@ -83,7 +83,8 @@
}
TrtCandidateSelector::TrtCandidateSelector(
- const grappler::GraphProperties& graph_properties, int precision_mode)
+ const grappler::GraphProperties& graph_properties,
+ TrtPrecisionMode precision_mode)
: graph_properties_(graph_properties), precision_mode_(precision_mode) {}
Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
@@ -98,6 +99,7 @@
"ConcatV2",
"Const",
"Conv2D",
+ "Conv2DBackpropInput",
"DepthwiseConv2dNative",
"Div",
"Exp",
@@ -105,6 +107,7 @@
"FusedBatchNorm",
"FusedBatchNormV2",
"Identity",
+ "LeakyRelu",
"Log",
"MatMul",
"Max",
@@ -149,7 +152,8 @@
// In INT8 mode, we will always apply the quantization ranges provided by
// these ops to the relevant tensors. This happens regardless of the value of
// use_calibration.
- if (precision_mode_ == INT8MODE && quantize_ops.count(node->type_string())) {
+ if (precision_mode_ == TrtPrecisionMode::INT8 &&
+ quantize_ops.count(node->type_string())) {
is_supported_op_type = true;
}
// LINT.ThenChange(//tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc)
@@ -239,7 +243,7 @@
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
- int precision_mode, int minimum_segment_size, bool is_dyn_op,
+ TrtPrecisionMode precision_mode, int minimum_segment_size, bool is_dyn_op,
int max_cached_engines, std::vector<int> cached_engine_batches,
bool use_calibration) {
// Create GrapplerItem.
@@ -299,7 +303,7 @@
parameters["max_batch_size"].set_i(max_batch_size);
parameters["is_dynamic_op"].set_b(is_dyn_op);
parameters["max_workspace_size_bytes"].set_i(max_workspace_size_bytes);
- TF_RETURN_IF_ERROR(GetPrecisionModeName(
+ TF_RETURN_IF_ERROR(TrtPrecisionModeToName(
precision_mode, parameters["precision_mode"].mutable_s()));
parameters["maximum_cached_engines"].set_i(max_cached_engines);
if (!cached_engine_batches.empty()) {
@@ -638,7 +642,7 @@
}
const bool calibrate_int8 =
- (info.precision_mode == INT8MODE && info.use_calibration);
+ (info.precision_mode == TrtPrecisionMode::INT8 && info.use_calibration);
// Build the engine and get its serialized representation.
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic || calibrate_int8) {
@@ -651,7 +655,8 @@
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
- info.segment_graph_def, calibrate_int8 ? FP32MODE : info.precision_mode,
+ info.segment_graph_def,
+ calibrate_int8 ? TrtPrecisionMode::FP32 : info.precision_mode,
max_batch_size, info.max_workspace_size_bytes, input_shapes,
&trt_logger, alloc, /*calibrator=*/nullptr, &engine,
info.use_calibration,
@@ -668,7 +673,7 @@
}
string prec_string;
- TF_RETURN_IF_ERROR(GetPrecisionModeName(info.precision_mode, &prec_string));
+ TF_RETURN_IF_ERROR(TrtPrecisionModeToName(info.precision_mode, &prec_string));
tensorflow::NodeDefBuilder node_builder(info.engine_name, "TRTEngineOp");
if (!info.device.empty()) node_builder.Device(info.device);
if (VLOG_IS_ON(1)) {
@@ -849,6 +854,12 @@
auto native_segment = fdeflib.add_function();
TF_RETURN_IF_ERROR(tensorflow::GraphToFunctionDef(
sgraph, StrCat(engine_name, "_native_segment"), native_segment));
+ // Set kIntsonDeviceAttr to true so that all TRTEngineOp outputs are always on
+ // a GPU device as expected. Otherwise, some of the tensors of type DT_INT32
+ // would be on host if the op generating the tensor has host memory tag set.
+ (*native_segment
+ ->mutable_attr())[FunctionLibraryDefinition::kIntsOnDeviceAttr]
+ .set_b(true);
if (VLOG_IS_ON(7)) {
VLOG(7) << engine_name << " Function_Def ";
VLOG(7) << native_segment->DebugString();
@@ -970,7 +981,8 @@
continue;
}
curr_engine.precision_mode = params.precision_mode;
- if (params.use_calibration && params.precision_mode != INT8MODE) {
+ if (params.use_calibration &&
+ params.precision_mode != TrtPrecisionMode::INT8) {
return errors::InvalidArgument(
"Calibration with FP32 or FP16 is not supported.");
}
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
index fb82a430..95cf022 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.h
@@ -36,7 +36,7 @@
class TrtCandidateSelector {
public:
TrtCandidateSelector(const grappler::GraphProperties& graph_properties,
- int precision_mode);
+ TrtPrecisionMode precision_mode);
// Returns OK iff 'node' is a TF-TRT conversion candidate, which will be added
// to TRT subgraph and later converted into TRT engine.
@@ -52,7 +52,7 @@
const grappler::GraphProperties& graph_properties_;
// Quantization ops are only converted when using quantized precisions.
- const int precision_mode_;
+ const TrtPrecisionMode precision_mode_;
};
struct ConversionParams {
@@ -61,7 +61,7 @@
max_batch_size(1),
max_workspace_size_bytes(1 << 30),
output_graph_def(nullptr),
- precision_mode(1),
+ precision_mode(TrtPrecisionMode::FP32),
minimum_segment_size(3),
graph_properties(nullptr),
cluster(nullptr),
@@ -74,7 +74,7 @@
size_t max_batch_size;
size_t max_workspace_size_bytes;
tensorflow::GraphDef* output_graph_def;
- int precision_mode;
+ TrtPrecisionMode precision_mode;
int minimum_segment_size;
const tensorflow::grappler::GraphProperties* graph_properties;
const tensorflow::grappler::Cluster* cluster;
@@ -99,9 +99,10 @@
const tensorflow::GraphDef& graph_def,
const std::vector<string>& output_names, size_t max_batch_size,
size_t max_workspace_size_bytes, tensorflow::GraphDef* new_graph_def,
- int precision_mode = 1, int minimum_segment_size = 3,
- bool is_dyn_op = false, int max_cached_engines = 1,
- std::vector<int> cached_engine_batches = {}, bool use_calibration = true);
+ TrtPrecisionMode precision_mode = TrtPrecisionMode::FP32,
+ int minimum_segment_size = 3, bool is_dyn_op = false,
+ int max_cached_engines = 1, std::vector<int> cached_engine_batches = {},
+ bool use_calibration = true);
// Method to call from optimization pass
tensorflow::Status ConvertAfterShapes(ConversionParams& params);
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
index a3c3a8a..cabc6cc 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph_test.cc
@@ -98,7 +98,8 @@
grappler::GraphProperties graph_properties(item);
TF_EXPECT_OK(graph_properties.InferStatically(true));
- for (const int precision_mode : {FP32MODE, INT8MODE}) {
+ for (const TrtPrecisionMode precision_mode :
+ {TrtPrecisionMode::FP32, TrtPrecisionMode::INT8}) {
TrtCandidateSelector selector(graph_properties, precision_mode);
TF_EXPECT_OK(selector.IsTensorRTCandidate(matmul.operation.node()));
ExpectStatus(
@@ -113,7 +114,7 @@
matmul_with_incompatible_input.operation.node()),
error::INTERNAL,
"Failed to convert input with index 0 to a TRT_TensorOrWeights");
- if (precision_mode == INT8MODE) {
+ if (precision_mode == TrtPrecisionMode::INT8) {
TF_EXPECT_OK(selector.IsTensorRTCandidate(quantize.operation.node()));
} else {
ExpectStatus(selector.IsTensorRTCandidate(quantize.operation.node()),
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
index ff6b6d2..79b1cba 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.cc
@@ -24,7 +24,9 @@
#include <utility>
#include <vector>
+#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/compiler/tf2tensorrt/plugin/trt_plugin_factory.h"
#include "tensorflow/compiler/tf2tensorrt/utils/trt_logger.h"
@@ -82,6 +84,13 @@
const char* const kInputPHName = "TensorRTInputPH_";
const char* const kOutputPHName = "TensorRTOutputPH_";
+bool IsEngineInput(absl::string_view name) {
+ return absl::StartsWith(name, kInputPHName);
+}
+bool IsEngineOutput(absl::string_view name) {
+ return absl::StartsWith(name, kOutputPHName);
+}
+
namespace convert {
using absl::StrAppend;
using absl::StrCat;
@@ -351,6 +360,26 @@
return trt_tensor;
}
+tensorflow::Status CreateBroadcastableScalarConstant(
+ OpConverterParams* params, float value, const nvinfer1::Dims& dims,
+ const nvinfer1::ITensor** tensor) {
+ // In order to be broadcastable, the number of dims has to match.
+ nvinfer1::Dims broadcastable_dims(dims);
+ for (int i = 0; i < broadcastable_dims.nbDims; i++) {
+ broadcastable_dims.d[i] = 1;
+ }
+ TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
+ tensorflow::DataType::DT_FLOAT, broadcastable_dims);
+ auto weights_ptr =
+ static_cast<float*>(const_cast<void*>(weights.GetValues()));
+ weights_ptr[0] = value;
+ *tensor = params->converter->CreateConstantLayer(weights, broadcastable_dims);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, params->node_def.name());
+ params->converter->ProvideQuantizationRange(
+ const_cast<nvinfer1::ITensor*>(*tensor), value, value);
+ return Status::OK();
+}
+
inline bool DimsEqual(const nvinfer1::Dims& dim_l,
const nvinfer1::Dims& dim_r) {
if (dim_l.nbDims != dim_r.nbDims) {
@@ -865,7 +894,7 @@
}
Converter::Converter(nvinfer1::INetworkDefinition* trt_network,
- int precision_mode, bool use_calibration)
+ TrtPrecisionMode precision_mode, bool use_calibration)
: trt_network_(trt_network),
precision_mode_(precision_mode),
use_calibration_(use_calibration) {
@@ -900,7 +929,7 @@
// in ConvertIdentity.
if (output.is_tensor()) {
const char* tensor_name = output.tensor()->getName();
- if (!tensorflow::str_util::StartsWith(tensor_name, kInputPHName)) {
+ if (!IsEngineInput(tensor_name)) {
// TRT initializes tensor names as "(Unnamed ITensor* N)". We rename
// them to match their corresponding TensorFlow name.
// Note: ITensors that we create internally within TF-TRT which are
@@ -960,15 +989,17 @@
return errors::NotFound("Output tensor not found: ",
output.source_tensor_name);
}
- // Check if this tensor has already been marked as an output.
+ // Check if this tensor has already been marked as an input or output.
+ //
// ConvertIdentity can cause the same tensor to be repeated in
// output_tensors, which can cause us to overwrite the name of the output
// tensor binding. For example, if we rename OutputPH_0 to OutputPH_1 then
// we won't be able to locate OutputPH_0 during runtime. To fix this,
// duplicate the tensor using no-op shuffle.
+ //
// TODO(tmorris): Remove this work-around once we use TRT's IIdentityLayer
// in ConvertIdentity.
- if (tensorflow::str_util::StartsWith(tensor->getName(), kOutputPHName)) {
+ if (IsEngineInput(tensor->getName()) || IsEngineOutput(tensor->getName())) {
// Using shuffle layer for identity by not setting reshape or transpose.
nvinfer1::IShuffleLayer* layer = network()->addShuffle(*tensor);
TFTRT_RETURN_ERROR_IF_NULLPTR(
@@ -1128,7 +1159,7 @@
} else {
*tensor = CreateConstantLayer(input.weights(), dims);
TFTRT_RETURN_ERROR_IF_NULLPTR(*tensor, "TF-TRT Internal Reshape");
- if (precision_mode() == INT8MODE && !use_calibration()) {
+ if (precision_mode() == TrtPrecisionMode::INT8 && !use_calibration()) {
// If we are in int8 mode and not calibrating, we need to explicitly set a
// quantization range for the output tensor of the IConstantLayer. Here we
// set the range to [min(weights), max(weights)].
@@ -1163,7 +1194,7 @@
}
void Converter::MaybeApplyQuantizationRanges() {
- if (precision_mode() != INT8MODE) return;
+ if (precision_mode() != TrtPrecisionMode::INT8) return;
// Infer ranges across marked ops.
PropagateQuantizationRanges();
@@ -1286,6 +1317,39 @@
return tensorflow::Status::OK();
}
+// Checks that the number of inputs match, and enforces that the inputs marked
+// as true are constant weights. true means that the input must be a weight,
+// while false means the input must be a tensor. In the future, false will mean
+// the input can be a tensor or weight.
+tensorflow::Status CheckInputsWeights(
+ const OpConverterParams& params,
+ const std::vector<std::pair<string, bool>>& inputs_is_weight) {
+ const auto& inputs = params.inputs;
+ const auto& node_def = params.node_def;
+ if (inputs.size() != inputs_is_weight.size()) {
+ return tensorflow::errors::InvalidArgument(
+ node_def.op(), " got ", inputs.size(), " inputs but expected ",
+ inputs_is_weight.size(), ", at ", node_def.name());
+ }
+ for (int i = 0; i < inputs.size(); i++) {
+ if (inputs_is_weight[i].second && inputs.at(i).is_tensor()) {
+ return tensorflow::errors::Unimplemented(
+ "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(),
+ " must be a constant, at ", node_def.name());
+ }
+ // TODO(tmorris): Remove this check and provide a method to automatically
+ // retrive an input as a tensor, converting via CreateConstantLayer if it
+ // was originally a weight. We will want a caching mechanism to prevent many
+ // duplicate constants from being created.
+ if (!inputs_is_weight[i].second && inputs.at(i).is_weights()) {
+ return tensorflow::errors::Unimplemented(
+ "The input \"", inputs_is_weight[i].first, "\" for ", node_def.op(),
+ " must be a tensor, at ", node_def.name());
+ }
+ }
+ return tensorflow::Status::OK();
+}
+
TRT_ShapedWeights ConvertFP32ToFP16(TrtWeightStore* store,
const TRT_ShapedWeights& weights_src) {
auto dtype_new = tensorflow::DataType::DT_HALF;
@@ -1478,7 +1542,7 @@
const_cast<nvinfer1::ITensor*>(tensor), permutation, &tensor));
}
- if (params->converter->precision_mode() == FP16MODE) {
+ if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
@@ -1521,7 +1585,7 @@
// Because of this issue, fall back to BinaryTensorOpTensor if we are
// doing INT8 with no calibration. There is most likely no performance
// penalty by falling back here.
- if (params->converter->precision_mode() == INT8MODE &&
+ if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
!params->converter->use_calibration()) {
return errors::Unimplemented(
"Intermediate quantization range cannot be determined without"
@@ -1571,25 +1635,24 @@
return tensorflow::Status::OK();
}
-enum class ConvolutionType { DEFAULT, DEPTHWISE_CONV };
-
-tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group) {
+tensorflow::Status ConvertConv2DHelper(OpConverterParams* params, int group,
+ bool is_conv2d_backprop_input) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2) {
- return tensorflow::errors::InvalidArgument("Two inputs are expected for ",
- node_def.op(), ", at ",
- node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- node_def.op(), " is only implemented for tensors, not weights, at ",
- node_def.name());
- }
- if (inputs.at(1).is_tensor()) {
- return tensorflow::errors::Unimplemented("Kernel for ", node_def.op(),
- " must be constant weights, at ",
- node_def.name());
+ TRT_TensorOrWeights backprop_output_size;
+ const nvinfer1::ITensor* tensor = nullptr;
+ if (is_conv2d_backprop_input) {
+ // In the case when Conv2dBackpropInput is used for conv2d_transpose, these
+ // inputs correspond to: output size, filter, and input.
+ TF_RETURN_IF_ERROR(CheckInputsWeights(
+ *params,
+ {{"input_sizes", true}, {"filter", true}, {"out_backprop", false}}));
+ backprop_output_size = inputs.at(0);
+ tensor = inputs.at(2).tensor();
+ } else {
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"input", false}, {"filter", true}}));
+ tensor = inputs.at(0).tensor();
}
TRT_ShapedWeights weights_rsck = inputs.at(1).weights();
if (weights_rsck.shape_.nbDims != 4) {
@@ -1613,6 +1676,11 @@
node_def.name());
}
const nvinfer1::DimsHW dilation(tf_dilations[h_index], tf_dilations[w_index]);
+ if (is_conv2d_backprop_input && (dilation.d[0] != 1 || dilation.d[1] != 1)) {
+ return tensorflow::errors::Unimplemented(
+ "Dilation with Conv2DBackpropInput (conv2d_transpose) is not supported",
+ ", at ", node_def.name());
+ }
const auto tf_stride = attrs.get<std::vector<int>>("strides");
if (tf_stride.size() != 4) {
@@ -1628,8 +1696,6 @@
const nvinfer1::DimsHW stride(tf_stride[h_index], tf_stride[w_index]);
if (params->validation_only) return tensorflow::Status::OK();
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
-
// Transpose to NCHW (NCHW is required for IConvLayer).
const bool need_transpose = (data_format == "NHWC");
if (need_transpose) {
@@ -1639,19 +1705,23 @@
// Dimensions of transposed tensor.
const auto tensor_dim = tensor->getDimensions();
- // For depthwise convolution, group will be 0 so set num_groups to size of
- // input's channel dim. For a non-depthwise conv, num_groups will be 1.
+ // group == 0 signifies that this is a depthwise convolution, so set
+ // num_groups to size of input's channel dim. For a non-depthwise conv,
+ // num_groups will be 1.
const int num_groups = (group == 0) ? tensor_dim.d[0] : group;
- if (params->converter->precision_mode() == FP16MODE) {
- weights_rsck =
- ConvertFP32ToFP16(params->weight_store, inputs.at(1).weights());
+ if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
+ weights_rsck = ConvertFP32ToFP16(params->weight_store, weights_rsck);
}
+ // For conv, TF weights are RSCK, and TRT expects KCRS.
+ // For backprop, TF weights are RSKC, and TRT expects CKRS.
+ // Therefore, this reorder will work for both cases.
TRT_ShapedWeights weights =
params->weight_store->GetTempWeights(weights_rsck);
ReorderRSCKToKCRS(weights_rsck, &weights, num_groups);
TRT_ShapedWeights biases(weights.type_);
- const int noutput = weights.shape_.d[0] * num_groups;
+ const int output_axis = is_conv2d_backprop_input ? 1 : 0;
+ const int noutput = weights.shape_.d[output_axis] * num_groups;
nvinfer1::DimsHW kernel_size;
kernel_size.h() = weights.shape_.d[2];
kernel_size.w() = weights.shape_.d[3];
@@ -1662,9 +1732,23 @@
nvinfer1::DimsHW effective_kernel_size = kernel_size;
effective_kernel_size.h() += (kernel_size.h() - 1) * (dilation.h() - 1);
effective_kernel_size.w() += (kernel_size.w() - 1) * (dilation.w() - 1);
- padding = CreateSamePadding(
- stride, effective_kernel_size,
- {static_cast<int>(tensor_dim.d[1]), static_cast<int>(tensor_dim.d[2])});
+ std::vector<int64_t> input_dims;
+ if (is_conv2d_backprop_input) {
+ // For backprop, calculate padding based on "input_sizes" input, which
+ // actually corresponds to output size. ("input_sizes" makes sense in the
+ // context of Conv2DBackpropInput).
+ // We use h_index and w_index instead of 1 and 2 because we havent
+ // transposed backprop_output_size along with the input.
+ auto output_size_weights = static_cast<int*>(
+ const_cast<void*>(backprop_output_size.weights().GetValues()));
+ input_dims = {output_size_weights[h_index], output_size_weights[w_index]};
+ } else {
+ // Use 1 and 2 because tensor_dim has the dimensions of the transposed
+ // input.
+ input_dims = {static_cast<int>(tensor_dim.d[1]),
+ static_cast<int>(tensor_dim.d[2])};
+ }
+ padding = CreateSamePadding(stride, effective_kernel_size, input_dims);
} else {
padding = {{0, 0}, {0, 0}};
}
@@ -1683,17 +1767,32 @@
}
// Add convolution.
- nvinfer1::IConvolutionLayer* layer =
- params->converter->network()->addConvolution(
- *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
- weights.GetTrtWeights(), biases.GetTrtWeights());
- TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
- layer->setStride(stride);
- layer->setPadding({padding[0].first, padding[1].first});
- layer->setName(node_def.name().c_str());
- layer->setNbGroups(num_groups);
- layer->setDilation(dilation);
- const nvinfer1::ITensor* output_tensor = layer->getOutput(0);
+ nvinfer1::ILayer* conv_layer = nullptr;
+ if (is_conv2d_backprop_input) {
+ nvinfer1::IDeconvolutionLayer* layer =
+ params->converter->network()->addDeconvolution(
+ *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
+ weights.GetTrtWeights(), biases.GetTrtWeights());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ layer->setNbGroups(num_groups);
+ conv_layer = layer;
+ } else {
+ nvinfer1::IConvolutionLayer* layer =
+ params->converter->network()->addConvolution(
+ *const_cast<nvinfer1::ITensor*>(tensor), noutput, kernel_size,
+ weights.GetTrtWeights(), biases.GetTrtWeights());
+ TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
+ layer->setStride(stride);
+ layer->setPadding({padding[0].first, padding[1].first});
+ layer->setName(node_def.name().c_str());
+ layer->setNbGroups(num_groups);
+ layer->setDilation(dilation);
+ conv_layer = layer;
+ }
+ const nvinfer1::ITensor* output_tensor = conv_layer->getOutput(0);
// Restore transpose.
if (need_transpose) {
@@ -1706,18 +1805,6 @@
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertConv2DHelper(OpConverterParams* params,
- ConvolutionType type) {
- switch (type) {
- case ConvolutionType::DEFAULT:
- return ConvertConv2DHelper(params, 1);
- case ConvolutionType::DEPTHWISE_CONV:
- return ConvertConv2DHelper(params, 0);
- }
- return tensorflow::errors::Unimplemented("Unsupported convolution type, at ",
- params->node_def.name());
-}
-
Status BinaryTensorOpTensor(OpConverterParams* params,
const TRT_TensorOrWeights& operand_l,
const TRT_TensorOrWeights& operand_r) {
@@ -1827,12 +1914,8 @@
tensorflow::Status ConvertTranspose(OpConverterParams* params) {
const auto& inputs = params->inputs;
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at ", params->node_def.name());
- }
-
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"x", false}, {"perm", true}}));
// Get the permutation from weights.
TRT_ShapedWeights weights = inputs.at(1).weights();
const int* weights_ptr =
@@ -1865,11 +1948,8 @@
tensorflow::Status ConvertReshape(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2 || !inputs.at(1).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "Input expects weights for shape, at ", node_def.name());
- }
-
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"tensor", false}, {"shape", true}}));
TRT_TensorOrWeights input_tensor = inputs.at(0);
TRT_ShapedWeights weights = inputs.at(1).weights();
if (weights.count() == 0) {
@@ -1965,18 +2045,8 @@
tensorflow::Status ConvertExpandDims(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2) {
- return tensorflow::errors::InvalidArgument(
- "Two inputs expected for ExpandDims, at ", node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- "ExpandDims expects tensor for input, at ", node_def.name());
- }
- if (!inputs.at(1).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "ExpandDims expects weights for axis, at ", node_def.name());
- }
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
// Get input shape as vector.
TRT_TensorOrWeights input_tensor = inputs.at(0);
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
@@ -2026,14 +2096,7 @@
tensorflow::Status ConvertSqueeze(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 1) {
- return tensorflow::errors::InvalidArgument(
- "One input expected for Squeeze, at ", node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- "Squeeze expects tensor for input, at ", node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
// Get input shape.
TRT_TensorOrWeights input_tensor = inputs.at(0);
const nvinfer1::Dims dims = input_tensor.GetTrtDims();
@@ -2135,20 +2198,9 @@
tensorflow::Status ConvertStridedSlice(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 4) {
- return tensorflow::errors::InvalidArgument(
- "StridedSlice expects 4 inputs, at ", node_def.name());
- }
- if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights() ||
- !inputs.at(3).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "StridedSlice expects weights for begin, end, and strides, at ",
- node_def.name());
- }
- if (!inputs.at(0).is_tensor()) {
- return tensorflow::errors::Unimplemented(
- "StridedSlice is only implemented for tensors, at ", node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(
+ *params,
+ {{"input", false}, {"begin", true}, {"end", true}, {"strides", true}}));
// Get input dims.
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
@@ -2329,21 +2381,21 @@
}
tensorflow::Status ConvertConv2D(OpConverterParams* params) {
- return ConvertConv2DHelper(params, ConvolutionType::DEFAULT);
+ return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/false);
}
tensorflow::Status ConvertConv2DDepthwise(OpConverterParams* params) {
- return ConvertConv2DHelper(params, ConvolutionType::DEPTHWISE_CONV);
+ return ConvertConv2DHelper(params, 0, /*is_conv2d_backprop_input=*/false);
+}
+
+tensorflow::Status ConvertConv2DBackpropInput(OpConverterParams* params) {
+ return ConvertConv2DHelper(params, 1, /*is_conv2d_backprop_input=*/true);
}
tensorflow::Status ConvertPool(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- node_def.op(), " is only implemented for tensors, not weights, at ",
- node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
nvinfer1::PoolingType type;
if (node_def.op() == "MaxPool") {
type = nvinfer1::PoolingType::kMAX;
@@ -2430,7 +2482,9 @@
return tensorflow::Status::OK();
}
-tensorflow::Status ConvertActivation(OpConverterParams* params) {
+// TODO(tmorris): Use ActivationType::kLEAKY_RELU in TRT 5.1+ once perf
+// improves.
+tensorflow::Status ConvertLeakyRelu(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 1) {
@@ -2442,6 +2496,47 @@
node_def.op(), " is only implemented for tensors, at ",
node_def.name());
}
+ TFAttrs attrs(node_def);
+ const float alpha = attrs.get<float>("alpha");
+ if (alpha < 0.0f || alpha > 1.0f) {
+ return tensorflow::errors::Unimplemented(
+ "Alpha value for LeakyRelu must be between 0 and 1, at ",
+ node_def.name());
+ }
+ if (params->validation_only) return tensorflow::Status::OK();
+
+ // Input Tensor
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ // Create const for alpha.
+ const nvinfer1::ITensor* const_alpha_tensor = nullptr;
+ TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
+ params, alpha, tensor->getDimensions(), &const_alpha_tensor));
+ // alpha * x
+ nvinfer1::IElementWiseLayer* mul_layer =
+ params->converter->network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ *const_cast<nvinfer1::ITensor*>(const_alpha_tensor),
+ nvinfer1::ElementWiseOperation::kPROD);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(mul_layer, node_def.name());
+ // max(x, alpha * x)
+ nvinfer1::IElementWiseLayer* max_layer =
+ params->converter->network()->addElementWise(
+ *const_cast<nvinfer1::ITensor*>(tensor),
+ *const_cast<nvinfer1::ITensor*>(mul_layer->getOutput(0)),
+ nvinfer1::ElementWiseOperation::kMAX);
+ TFTRT_RETURN_ERROR_IF_NULLPTR(max_layer, node_def.name());
+ nvinfer1::ITensor* output_tensor = max_layer->getOutput(0);
+ params->converter->MarkQuantizationRangesAsInferrable(
+ output_tensor, const_cast<nvinfer1::ITensor*>(mul_layer->getOutput(0)));
+
+ params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
+ return Status::OK();
+}
+
+tensorflow::Status ConvertActivation(OpConverterParams* params) {
+ const auto& inputs = params->inputs;
+ const auto& node_def = params->node_def;
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
static const std::unordered_map<string, nvinfer1::ActivationType> ops{
{"Relu", nvinfer1::ActivationType::kRELU},
{"Sigmoid", nvinfer1::ActivationType::kSIGMOID},
@@ -2475,19 +2570,19 @@
Status ConvertQuantize(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if ((inputs.size() == 0) ||
- (node_def.op() == "FakeQuantWithMinMaxArgs" && inputs.size() != 1) ||
- (node_def.op() == "FakeQuantWithMinMaxVars" && inputs.size() != 3) ||
- (node_def.op() == "QuantizeAndDequantizeV2" && inputs.size() != 3) ||
- (node_def.op() == "QuantizeAndDequantizeV3" && inputs.size() != 4)) {
- return errors::InvalidArgument("Invalid number of inputs for ",
- node_def.op(), ", at ", node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- // TensorRT will automatically quantize weights, so we will ignore ranges
- // for weights.
- params->outputs->push_back(inputs.at(0));
- return Status::OK();
+ if (node_def.op() == "FakeQuantWithMinMaxArgs") {
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
+ } else if (node_def.op() == "FakeQuantWithMinMaxVars") {
+ TF_RETURN_IF_ERROR(CheckInputsWeights(
+ *params, {{"input", false}, {"min", true}, {"max", true}}));
+ } else if (node_def.op() == "QuantizeAndDequantizeV2") {
+ TF_RETURN_IF_ERROR(CheckInputsWeights(
+ *params, {{"input", false}, {"input_min", true}, {"input_max", true}}));
+ } else if (node_def.op() == "QuantizeAndDequantizeV3") {
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false},
+ {"input_min", true},
+ {"input_max", true},
+ {"num_bits", true}}));
}
float min_range = 0.0f;
float max_range = 0.0f;
@@ -2504,11 +2599,6 @@
node_def.op() == "QuantizeAndDequantizeV2" ||
node_def.op() == "QuantizeAndDequantizeV3") {
// Get ranges via inputs.
- if (!inputs.at(1).is_weights() || !inputs.at(2).is_weights()) {
- return errors::InvalidArgument("Min and max inputs for ", node_def.op(),
- " must be weights not tensors, at ",
- node_def.name());
- }
auto get_weights_value = [&inputs](int index) {
auto raw_weights = static_cast<float*>(
const_cast<void*>(inputs.at(index).weights().GetValues()));
@@ -2539,20 +2629,11 @@
return Status::OK();
}
-// TODO(pdavoodi): we should update relu6 implementation once TensorRT supports
-// Relu6 natively.
+// TODO(tmorris): Use ActivationType::kCLIP in TRT 5.1+ once perf improves.
tensorflow::Status ConvertRelu6(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 1) {
- return tensorflow::errors::InvalidArgument(
- "Invalid number of inputs for Relu6, at ", node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- "Relu6 is only implemented for tensors, not weights, at ",
- node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"input", false}}));
if (params->validation_only) return Status::OK();
// ***************************************************************************
// TensorRT does not implement Relu6 natively. This function converts Relu6 op
@@ -2576,24 +2657,10 @@
params->converter->ProvideQuantizationRange(relu_layer->getOutput(0), 0.0f,
6.0f);
- // Create a constant layer to store the floating point weight i.e. 6.0f This
- // tensor will be broadcasted uniformly during elementwise `min` operation.
- // The constant has to have the same rank as the input in order for TRT to
- // broadcast
- nvinfer1::Dims dims;
- dims.nbDims = relu_layer->getOutput(0)->getDimensions().nbDims;
- for (int i = 0; i < dims.nbDims; i++) {
- dims.d[i] = 1;
- }
- TRT_ShapedWeights weights = params->weight_store->GetTempWeights(
- tensorflow::DataType::DT_FLOAT, dims);
- auto weights_ptr =
- static_cast<float*>(const_cast<void*>(weights.GetValues()));
- weights_ptr[0] = 6.0f;
- nvinfer1::ITensor* const6_tensor =
- params->converter->CreateConstantLayer(weights, dims);
- TFTRT_RETURN_ERROR_IF_NULLPTR(const6_tensor, node_def.name());
- params->converter->ProvideQuantizationRange(const6_tensor, 0.0f, 6.0f);
+ // Create a constant layer to store the floating point weight i.e. 6.0f
+ const nvinfer1::ITensor* const6_tensor = nullptr;
+ TF_RETURN_IF_ERROR(CreateBroadcastableScalarConstant(
+ params, 6.0f, relu_layer->getOutput(0)->getDimensions(), &const6_tensor));
// ElementWise Min Operation
// Min op is a nop for INT8 execution path, as the input tensor
@@ -2601,7 +2668,8 @@
nvinfer1::IElementWiseLayer* relu6_layer =
params->converter->network()->addElementWise(
*const_cast<nvinfer1::ITensor*>(relu_layer->getOutput(0)),
- *const6_tensor, nvinfer1::ElementWiseOperation::kMIN);
+ *const_cast<nvinfer1::ITensor*>(const6_tensor),
+ nvinfer1::ElementWiseOperation::kMIN);
TFTRT_RETURN_ERROR_IF_NULLPTR(relu6_layer, node_def.name());
nvinfer1::ITensor* output_tensor = relu6_layer->getOutput(0);
params->converter->ProvideQuantizationRange(output_tensor, 0.0f, 6.0f);
@@ -2613,11 +2681,8 @@
tensorflow::Status ConvertBiasAdd(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights()) {
- return errors::InvalidArgument("Input expects tensor and weights, at ",
- node_def.name());
- }
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"value", false}, {"bias", true}}));
TFAttrs attrs(node_def);
tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
if (tf_dtype != DataType::DT_FLOAT && tf_dtype != DataType::DT_HALF) {
@@ -2675,7 +2740,7 @@
}
TRT_ShapedWeights weights = inputs.at(1).weights();
- if (params->converter->precision_mode() == FP16MODE) {
+ if (params->converter->precision_mode() == TrtPrecisionMode::FP16) {
weights = ConvertFP32ToFP16(params->weight_store, weights);
}
nvinfer1::ScaleMode mode = nvinfer1::ScaleMode::kCHANNEL;
@@ -2837,9 +2902,13 @@
Status ConvertBinary(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
+ // TODO(tmorris): Enable once false is updated to mean either tensor or weight
+ // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
+ // false}}));
if (inputs.size() != 2) {
- return errors::InvalidArgument("Binary ops require two inputs, at ",
- node_def.name());
+ return tensorflow::errors::InvalidArgument(
+ node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ",
+ node_def.name());
}
// Constant folding should have been done by TensorFlow
@@ -2889,11 +2958,7 @@
{"Abs", nvinfer1::UnaryOperation::kABS},
{"Reciprocal", nvinfer1::UnaryOperation::kRECIP},
};
-
- if (inputs.size() != 1) {
- return tensorflow::errors::FailedPrecondition(
- "Unary ops require single tensor input, at ", node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
// TODO(jie): check type
const nvinfer1::ITensor* tensor = nullptr;
@@ -2908,7 +2973,7 @@
// x -> [Sqrt] -> sqrt(x) -> [Recip] -> 1/sqrt(x)
// ^
// need range here
- if (params->converter->precision_mode() == INT8MODE &&
+ if (params->converter->precision_mode() == TrtPrecisionMode::INT8 &&
!params->converter->use_calibration()) {
return errors::Unimplemented(
"Intermediate quantization range cannot be determined without"
@@ -2942,14 +3007,7 @@
tensorflow::Status ConvertSquare(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 1) {
- return tensorflow::errors::InvalidArgument("Square expects one input, at ",
- node_def.name());
- }
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- "Square is only implemented for tensors, at ", node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}}));
if (params->validation_only) return Status::OK();
// Constant 2 with same rank as input
@@ -2981,11 +3039,8 @@
tensorflow::Status ConvertReduce(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at", node_def.name());
- }
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"input", false}, {"axis", true}}));
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
TRT_ShapedWeights index_list = inputs.at(1).weights();
@@ -3046,12 +3101,8 @@
tensorflow::Status ConvertPad(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- // TODO(aaroey): make a routine for this check and reuse it.
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights()) {
- return tensorflow::errors::InvalidArgument(
- "Input expects tensor and weights, at", node_def.name());
- }
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"tensor", false}, {"paddings", true}}));
// Implement tensor binaryOp weight [channel wise] for now;
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
@@ -3232,6 +3283,11 @@
tensorflow::Status ConvertFusedBatchNorm(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false},
+ {"scale", true},
+ {"offset", true},
+ {"mean", true},
+ {"variance", true}}));
TFAttrs attrs(node_def);
float epsilon = attrs.get<float>("epsilon");
auto data_format = attrs.get<string>("data_format");
@@ -3252,21 +3308,6 @@
node_def.op(), " only supports is_training=false, at ",
node_def.name());
}
- if (inputs.at(0).is_weights()) {
- return tensorflow::errors::Unimplemented(
- node_def.op(),
- " is only implemented for tensor inputs, not weights, at ",
- node_def.name());
- }
- for (int i = 1; i < 5; i++) {
- if (inputs.at(i).is_tensor()) {
- return tensorflow::errors::Unimplemented(
- node_def.op(),
- " must have constant inputs for scale, offset, mean and variance, "
- "at ",
- node_def.name());
- }
- }
nvinfer1::ITensor const* tensor = inputs.at(0).tensor();
// Check parameter types
@@ -3423,11 +3464,7 @@
tensorflow::Status ConvertMatMul(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
- if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
- !inputs.at(1).is_weights()) {
- return errors::InvalidArgument("Input expects tensor and weights, at ",
- node_def.name());
- }
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"a", false}, {"b", true}}));
TFAttrs attrs(node_def);
tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
@@ -3453,6 +3490,14 @@
tensorflow::Status ConvertBatchMatMul(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
+ // TODO(tmorris): Enable once false is updated to mean either tensor or weight
+ // TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
+ // false}}));
+ if (inputs.size() != 2) {
+ return tensorflow::errors::InvalidArgument(
+ node_def.op(), " got ", inputs.size(), " inputs but expected 2, at ",
+ node_def.name());
+ }
TFAttrs attrs(node_def);
tensorflow::DataType tf_dtype = attrs.get<tensorflow::DataType>("T");
@@ -3524,6 +3569,7 @@
tensorflow::Status ConvertSoftmax(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
+ TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"logits", false}}));
const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
int nbDims = tensor->getDimensions().nbDims;
@@ -3532,6 +3578,8 @@
"TensorRT Softmax cannot apply on batch dimension, at" +
node_def.name());
}
+ if (params->validation_only) return Status::OK();
+
nvinfer1::ISoftMaxLayer* layer = params->converter->network()->addSoftMax(
*const_cast<nvinfer1::ITensor*>(tensor));
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
@@ -3547,31 +3595,36 @@
tensorflow::Status ConvertTopK(OpConverterParams* params) {
const auto& inputs = params->inputs;
- const auto& node_def = params->node_def;
- const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ if (inputs.size() != 2 || !inputs.at(0).is_tensor() ||
+ !inputs.at(1).is_weights()) {
+ return errors::InvalidArgument("Input expects tensor and weights, at ",
+ params->node_def.name());
+ }
- int nbDims = tensor->getDimensions().nbDims;
- if (nbDims == 0) {
- return tensorflow::errors::InvalidArgument(
- "TensorRT TopK cannot apply on batch dimension, at" + node_def.name());
+ const auto& node_def = params->node_def;
+ TF_RETURN_IF_ERROR(
+ CheckInputsWeights(*params, {{"input", false}, {"k", true}}));
+ const nvinfer1::ITensor* tensor = inputs.at(0).tensor();
+ const int num_dims = tensor->getDimensions().nbDims;
+ if (num_dims == 0) {
+ return errors::InvalidArgument(
+ "TensorRT TopK cannot apply on batch dimension, at", node_def.name());
}
TRT_ShapedWeights k_w = inputs.at(1).weights();
- int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues())));
-
- nvinfer1::TopKOperation op;
- uint32_t reducedAxes = 0;
- if (node_def.op() == "TopKV2") {
- op = nvinfer1::TopKOperation::kMAX;
- reducedAxes |= 1 << (nbDims - 1);
- } else {
- return tensorflow::errors::Unimplemented(
- "Operation: ", node_def.op(),
- " not implemented, at: ", node_def.name());
+ if (k_w.count() != 1) {
+ return errors::InvalidArgument("k value of TopK should be a scalar, at",
+ node_def.name());
}
+ // Note that ITopKLayer always have sorted outputs, so we don't need to handle
+ // the 'sorted' attribute of the node.
+ if (params->validation_only) return Status::OK();
+ const nvinfer1::TopKOperation op = nvinfer1::TopKOperation::kMAX;
+ const int k = *(static_cast<int*>(const_cast<void*>(k_w.GetValues())));
+ const uint32_t reduce_axes = 1 << (num_dims - 1);
nvinfer1::ITopKLayer* layer = params->converter->network()->addTopK(
- *const_cast<nvinfer1::ITensor*>(tensor), op, k, reducedAxes);
+ *const_cast<nvinfer1::ITensor*>(tensor), op, k, reduce_axes);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
nvinfer1::ITensor* output_value_tensor = layer->getOutput(0);
@@ -3588,8 +3641,10 @@
(*registration)["ConcatV2"] = ConvertConcat;
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
+ (*registration)["Conv2DBackpropInput"] = ConvertConv2DBackpropInput;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["ExpandDims"] = ConvertExpandDims;
+ (*registration)["LeakyRelu"] = ConvertLeakyRelu;
(*registration)["MatMul"] = ConvertMatMul;
(*registration)["Pad"] = ConvertPad;
(*registration)["Relu6"] = ConvertRelu6;
@@ -3598,6 +3653,7 @@
(*registration)["Squeeze"] = ConvertSqueeze;
(*registration)["StridedSlice"] = ConvertStridedSlice;
(*registration)["Transpose"] = ConvertTranspose;
+ (*registration)["TopKV2"] = ConvertTopK;
for (auto quantization_op_type :
{"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
@@ -3644,14 +3700,13 @@
op_registry_["Mean"] = ConvertReduce;
op_registry_["Softmax"] = ConvertSoftmax;
op_registry_["BatchMatMul"] = ConvertBatchMatMul;
- op_registry_["TopKV2"] = ConvertTopK;
plugin_converter_ = ConvertPlugin;
}
tensorflow::Status ConvertGraphDefToEngine(
- const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
- size_t max_workspace_size_bytes,
+ const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode,
+ int max_batch_size, size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
@@ -3666,9 +3721,9 @@
builder->setMaxBatchSize(max_batch_size);
builder->setMaxWorkspaceSize(max_workspace_size_bytes);
builder->setGpuAllocator(allocator);
- if (precision_mode == FP16MODE) {
+ if (precision_mode == TrtPrecisionMode::FP16) {
builder->setHalf2Mode(true);
- } else if (precision_mode == INT8MODE) {
+ } else if (precision_mode == TrtPrecisionMode::INT8) {
builder->setInt8Mode(true);
if (use_calibration) {
builder->setInt8Calibrator(calibrator);
@@ -3693,8 +3748,7 @@
for (const auto& node_def : gdef.node()) {
string node_name = node_def.name();
VLOG(2) << "Converting op name=" << node_name << ", op=" << node_def.op();
- if (tensorflow::str_util::StartsWith(node_name, kInputPHName) &&
- (node_def.op() == "Placeholder")) {
+ if (IsEngineInput(node_name) && (node_def.op() == "Placeholder")) {
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kInputPHName), &slot_number)) {
@@ -3722,8 +3776,7 @@
// engines offline, by calling sess.run() and cache/serialize the engines.
TF_RETURN_IF_ERROR(
converter.AddInputTensor(node_name, trt_dtype, trt_dims, batch_size));
- } else if (tensorflow::str_util::StartsWith(node_name, kOutputPHName) &&
- (node_def.op() == "Identity")) {
+ } else if (IsEngineOutput(node_name) && (node_def.op() == "Identity")) {
int32 slot_number = -1;
if (!tensorflow::strings::safe_strto32( // non-absl ok
node_name.c_str() + strlen(kOutputPHName), &slot_number)) {
@@ -3877,7 +3930,7 @@
TensorId input = ParseTensorName(snode->input(input_idx));
if (!subgraph_node_names.count(
string(input.first.data(), input.first.size())) &&
- !str_util::StartsWith(input.first, kInputPHName)) {
+ !IsEngineInput(input.first)) {
if (input.second == Graph::kControlSlot) {
VLOG(1) << "... removing control inputs " << input.first
<< " from subgraph.";
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
index 0a33d00..d1e30eb 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes.h
@@ -92,7 +92,7 @@
EngineInfo()
: engine_type(EngineType::TRTStatic),
max_workspace_size_bytes(0),
- precision_mode(FP32MODE),
+ precision_mode(TrtPrecisionMode::FP32),
use_calibration(true) {}
string engine_name;
@@ -109,7 +109,7 @@
int64 max_workspace_size_bytes;
int maximum_cached_engines;
std::vector<int> cached_engine_batches;
- int precision_mode;
+ TrtPrecisionMode precision_mode;
bool use_calibration;
};
@@ -141,8 +141,8 @@
// is successful. This is different than successfully building the engine:
// building can still fail afterwards.
tensorflow::Status ConvertGraphDefToEngine(
- const tensorflow::GraphDef& gdef, int precision_mode, int max_batch_size,
- size_t max_workspace_size_bytes,
+ const tensorflow::GraphDef& gdef, TrtPrecisionMode precision_mode,
+ int max_batch_size, size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
Logger* logger, nvinfer1::IGpuAllocator* allocator,
TRTInt8Calibrator* calibrator,
@@ -178,6 +178,8 @@
nvinfer1::Weights GetTrtWeights() const;
+ // Returns the raw pointer to the underlying buffer which holds the weights
+ // value.
void* GetValues() const {
return const_cast<char*>(tensor_.tensor_data().data());
}
@@ -400,21 +402,6 @@
// Class to convert TF nodes to TRT network.
class Converter {
public:
- Converter(nvinfer1::INetworkDefinition* trt_network, int precision_mode,
- bool use_calibration);
-
- //////////////////////////////////////////////////////////////////////////////
- // Methods used by the TRT engine builder to build a TRT network from a TF
- // function/subgraph.
-
- // Convert the node to TRT network.
- Status ConvertNode(const tensorflow::NodeDef& node_def);
-
- // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and
- // 'batch_size'.
- Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
- const nvinfer1::Dims& dims, int batch_size);
-
// Used for Converter::RenameAndMarkOutputTensors()
struct EngineOutputInfo {
// The TRT tensor name which produces the output.
@@ -428,6 +415,21 @@
nvinfer1::DataType trt_dtype;
};
+ Converter(nvinfer1::INetworkDefinition* trt_network,
+ TrtPrecisionMode precision_mode, bool use_calibration);
+
+ //////////////////////////////////////////////////////////////////////////////
+ // Methods used by the TRT engine builder to build a TRT network from a TF
+ // function/subgraph.
+
+ // Convert the node to TRT network.
+ Status ConvertNode(const tensorflow::NodeDef& node_def);
+
+ // Add input tensor to the TRT network with given 'name', 'dtype', 'dims' and
+ // 'batch_size'.
+ Status AddInputTensor(const string& name, nvinfer1::DataType dtype,
+ const nvinfer1::Dims& dims, int batch_size);
+
// Mark the tensors with names specified by source_tensor_name as output of
// the TRT network, and set their names in the TRT network as dest_node_name.
Status RenameAndMarkOutputTensors(
@@ -442,7 +444,7 @@
nvinfer1::INetworkDefinition* network() { return trt_network_; }
// What precision are we targeting?
- int precision_mode() const { return precision_mode_; }
+ TrtPrecisionMode precision_mode() const { return precision_mode_; }
// Calibration will be or was previously performed on this network?
bool use_calibration() const { return use_calibration_; }
@@ -544,7 +546,7 @@
std::vector<std::pair<nvinfer1::ITensor*, nvinfer1::ITensor*>>
quantization_infer_;
- const int precision_mode_;
+ const TrtPrecisionMode precision_mode_;
const bool use_calibration_;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
index 39f4fb5..77221f6 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/convert_nodes_test.cc
@@ -21,7 +21,11 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
@@ -36,6 +40,7 @@
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -154,7 +159,7 @@
}
template <typename T>
-void ExpectArrayNear(const std::vector<T>& lhs, const std::vector<T>& rhs) {
+void ExpectArrayNear(const std::vector<T>& lhs, absl::Span<const T> rhs) {
ASSERT_EQ(lhs.size(), rhs.size());
for (int i = 0; i < lhs.size(); i++) {
EXPECT_FLOAT_EQ(lhs[i], rhs[i]);
@@ -165,7 +170,7 @@
// EXPECT_FLOAT_EQ.
template <>
void ExpectArrayNear(const std::vector<Eigen::half>& lhs,
- const std::vector<Eigen::half>& rhs) {
+ absl::Span<const Eigen::half> rhs) {
ASSERT_EQ(lhs.size(), rhs.size());
for (int i = 0; i < lhs.size(); i++) {
EXPECT_FLOAT_EQ(Eigen::half_impl::half_to_float(lhs[i]),
@@ -480,8 +485,7 @@
ConverterTest() {
builder_.reset(nvinfer1::createInferBuilder(logger_));
network_.reset(builder_->createNetwork());
- converter_.reset(new Converter(network_.get(),
- /*precision_mode=*/FP32MODE,
+ converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32,
/*use_calibration=*/false));
weight_store_ = &converter_->weight_store_;
}
@@ -783,7 +787,7 @@
// input -> infer1 -> infer2 -> infer3
FakeITensor input, infer_1, infer_2, infer_3;
FakeITensor not_infer;
- Converter int8_converter(/*trt_network=*/nullptr, INT8MODE,
+ Converter int8_converter(/*trt_network=*/nullptr, TrtPrecisionMode::INT8,
/*use_calibration=*/true);
int8_converter.ProvideQuantizationRange(&input, -5.0f, 5.0f);
int8_converter.ProvideQuantizationRange(¬_infer, -100.0f, 100.0f);
@@ -928,6 +932,83 @@
}
}
+class ConvertGraphDefToEngineTest : public ::testing::Test {
+ public:
+ Status RunConvertGraphDefToEngine(Scope* s) {
+ GraphDef gdef;
+ TF_EXPECT_OK(s->ToGraphDef(&gdef));
+ std::vector<tensorflow::PartialTensorShape> input_shapes;
+ int batch_size = -1;
+ for (const NodeDef& node : gdef.node()) {
+ absl::string_view node_name(node.name());
+ if (str_util::ConsumePrefix(&node_name, kInputPHName)) {
+ int port = -1;
+ EXPECT_TRUE(absl::SimpleAtoi(node_name, &port)) << node.name();
+ if (input_shapes.size() < port + 1) input_shapes.resize(port + 1);
+ input_shapes[port] =
+ PartialTensorShape(node.attr().at("shape").shape());
+ if (batch_size == -1) {
+ batch_size = input_shapes[port].dim_size(0);
+ } else {
+ EXPECT_EQ(batch_size, input_shapes[port].dim_size(0));
+ }
+ }
+ }
+ // TODO(laigd): execute the engine and get outputs.
+ return ConvertGraphDefToEngine(
+ gdef, TrtPrecisionMode::FP32, /*max_batch_size=*/1,
+ /*max_workspace_size_bytes=*/64 << 20, input_shapes, &logger_,
+ /*allocator=*/nullptr, /*calibrator=*/nullptr, &engine_,
+ /*use_calibration=*/false, /*convert_successfully=*/nullptr);
+ }
+
+ protected:
+ TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
+
+ private:
+ Logger logger_;
+};
+
+TEST_F(ConvertGraphDefToEngineTest, IdentityGraph) {
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName(StrCat(kInputPHName, 0)), DT_FLOAT,
+ ops::Placeholder::Shape({1, 1}));
+ auto output = ops::Identity(s.WithOpName("identity1"), input);
+ output = ops::Identity(s.WithOpName("identity2"), output);
+ output = ops::Identity(s.WithOpName(StrCat(kOutputPHName, 0)), output);
+ // If the converter marks the input tensor as output tensor, the conversion
+ // below will fail with:
+ // > TensorRTOutputPH_0 cannot be both input and output
+ // > Network must have at least one output
+ TF_EXPECT_OK(RunConvertGraphDefToEngine(&s));
+}
+
+// Input/output data format for OpConverterTest::BuildAndRun().
+struct InputOutputData {
+ void* Buffer() const {
+ return const_cast<char*>(tensor.tensor_data().data());
+ }
+
+ size_t TotalBytes() const { return tensor.TotalBytes(); }
+
+ const char* name;
+ Tensor tensor;
+};
+
+template <typename T>
+Tensor ConstructTensor(int data_size, const T& value = T()) {
+ std::vector<T> values(data_size, value);
+ return test::AsTensor<T>(values);
+}
+
+using DataVec = std::vector<InputOutputData>;
+
+template <typename T>
+inline absl::Span<const T> GetSpanForData(const InputOutputData& data) {
+ const auto& tensor_map = data.tensor.flat<T>();
+ return absl::Span<const T>(tensor_map.data(), tensor_map.size());
+}
+
// Class to test various op converters, using both a TrtNodeValidator and
// Converter.
class OpConverterTest : public ::testing::Test {
@@ -953,11 +1034,11 @@
builder_.reset(nvinfer1::createInferBuilder(logger_));
network_.reset(builder_->createNetwork());
builder_->setMaxBatchSize(1);
+ builder_->setMaxWorkspaceSize(1 << 26);
// Reset the validator and converter.
validator_.reset(new TrtNodeValidator);
- converter_.reset(new Converter(network_.get(),
- /*precision_mode=*/FP32MODE,
+ converter_.reset(new Converter(network_.get(), TrtPrecisionMode::FP32,
/*use_calibration=*/false));
// Reset other related artifacts.
@@ -966,15 +1047,14 @@
}
// TODO(laigd): test fp16 and int8 support.
- template <typename T>
- void BuildAndRun(
- const std::vector<std::pair<const char*, const std::vector<T>>>&
- input_data,
- const char* output_name, std::vector<T>* output_data) {
+ void BuildAndRun(const DataVec& input_data, DataVec* output_data) {
// Mark the output tensor as TRT engine output.
- TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(
- {{string(output_name), string(output_name),
- TfDataTypeToTrt(DataTypeToEnum<T>::v())}}));
+ std::vector<Converter::EngineOutputInfo> output_info;
+ for (const auto& data : *output_data) {
+ output_info.push_back(
+ {data.name, data.name, TfDataTypeToTrt(data.tensor.dtype())});
+ }
+ TF_EXPECT_OK(converter_->RenameAndMarkOutputTensors(output_info));
// Build the TRT engine.
ASSERT_EQ(nullptr, engine_.get());
@@ -982,31 +1062,44 @@
CHECK_NOTNULL(engine_.get());
// Execute the TRT engine.
- ASSERT_LE(input_data.size() + 1, 3);
- void* buffers[3];
- for (const auto name_and_data : input_data) {
- const int input_size = name_and_data.second.size() * sizeof(T);
- const int input_index = engine_->getBindingIndex(name_and_data.first);
- ASSERT_EQ(0, cudaMalloc(&buffers[input_index], input_size));
- ASSERT_EQ(
- 0, cudaMemcpyAsync(buffers[input_index], name_and_data.second.data(),
- input_size, cudaMemcpyHostToDevice, stream_));
+ const int num_bindings = input_data.size() + output_data->size();
+ std::vector<void*> buffers(num_bindings);
+
+ for (const auto& data : input_data) {
+ const int input_index = engine_->getBindingIndex(data.name);
+ ASSERT_EQ(0, cudaMalloc(&buffers[input_index], data.TotalBytes()));
+ ASSERT_EQ(0, cudaMemcpyAsync(buffers[input_index], data.Buffer(),
+ data.TotalBytes(), cudaMemcpyHostToDevice,
+ stream_));
+ }
+ struct SizeAndIndex {
+ SizeAndIndex(int in_size, int in_index)
+ : size(in_size), index(in_index) {}
+ int size;
+ int index;
+ };
+ std::vector<SizeAndIndex> output_infos;
+ for (const auto& data : *output_data) {
+ const int output_index = engine_->getBindingIndex(data.name);
+ output_infos.emplace_back(data.TotalBytes(), output_index);
+ ASSERT_EQ(0, cudaMalloc(&buffers[output_index], data.TotalBytes()));
}
- const int output_size = output_data->size() * sizeof(T);
- const int output_index = engine_->getBindingIndex(output_name);
- ASSERT_EQ(0, cudaMalloc(&buffers[output_index], output_size));
-
- ASSERT_EQ(engine_->getNbBindings(), input_data.size() + 1);
-
+ ASSERT_EQ(engine_->getNbBindings(), num_bindings);
TrtUniquePtrType<nvinfer1::IExecutionContext> execution_context(
engine_->createExecutionContext());
- execution_context->enqueue(/*batchSize=*/1, buffers, stream_, nullptr);
- ASSERT_EQ(0, cudaMemcpyAsync(output_data->data(), buffers[output_index],
- output_size, cudaMemcpyDeviceToHost, stream_));
+ execution_context->enqueue(/*batchSize=*/1, buffers.data(), stream_,
+ nullptr);
+
+ for (int i = 0; i < output_infos.size(); ++i) {
+ const auto& output_info = output_infos[i];
+ ASSERT_EQ(0, cudaMemcpyAsync(output_data->at(i).Buffer(),
+ buffers[output_info.index], output_info.size,
+ cudaMemcpyDeviceToHost, stream_));
+ }
cudaStreamSynchronize(stream_);
- for (int i = 0; i < input_data.size() + 1; ++i) {
+ for (int i = 0; i < num_bindings; ++i) {
ASSERT_EQ(0, cudaFree(buffers[i]));
}
}
@@ -1254,7 +1347,7 @@
NodeDef node_def = MakeNodeDef("my_transpose", "Transpose", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Input expects tensor and weights, at my_transpose");
+ "Transpose got 0 inputs but expected 2, at my_transpose");
}
// Get the NodeDef for Transpose.
@@ -1270,8 +1363,8 @@
AddTestTensor("input", {1, 2, 3});
AddTestTensor("weights", {3});
RunValidationAndConversion(
- node_def, error::INVALID_ARGUMENT,
- "Input expects tensor and weights, at my_transpose");
+ node_def, error::UNIMPLEMENTED,
+ "The input \"perm\" for Transpose must be a constant, at my_transpose");
}
{
// Transpose at batch dimension, should fail.
@@ -1301,10 +1394,12 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({3, 1, 2}, output.tensor()->getDimensions());
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_transpose",
- &output_data);
- EXPECT_THAT(output_data, ElementsAre(1, 4, 2, 5, 3, 6));
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
+ DataVec output_data{{"my_transpose", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(1, 4, 2, 5, 3, 6));
}
}
@@ -1314,7 +1409,7 @@
NodeDef node_def = MakeNodeDef("my_reshape", "Reshape", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Input expects weights for shape, at my_reshape");
+ "Reshape got 0 inputs but expected 2, at my_reshape");
}
// Get the NodeDef for Reshape.
@@ -1330,8 +1425,8 @@
AddTestTensor("input", {1, 2, 3});
AddTestTensor("weights", {3});
RunValidationAndConversion(
- node_def, error::INVALID_ARGUMENT,
- "Input expects weights for shape, at my_reshape");
+ node_def, error::UNIMPLEMENTED,
+ "The input \"shape\" for Reshape must be a constant, at my_reshape");
}
{
// Reshape to scalar, should fail.
@@ -1391,10 +1486,12 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({1, 3, 2}, output.tensor()->getDimensions());
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_reshape",
- &output_data);
- EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6));
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
+ DataVec output_data{{"my_reshape", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(1, 2, 3, 4, 5, 6));
}
}
@@ -1404,7 +1501,7 @@
NodeDef node_def = MakeNodeDef("my_matmul", "MatMul", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Input expects tensor and weights, at my_matmul");
+ "MatMul got 0 inputs but expected 2, at my_matmul");
}
// Get the NodeDef for MatMul.
@@ -1454,12 +1551,13 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
- std::vector<float> output_data(2);
- BuildAndRun<float>({{"input", {0, 1}}}, "my_matmul", &output_data);
+ const DataVec input_data{{"input", test::AsTensor<float>({0, 1})}};
+ DataVec output_data{{"my_matmul", ConstructTensor<float>(2)}};
+ BuildAndRun(input_data, &output_data);
if (transpose_b) {
- EXPECT_THAT(output_data, ElementsAre(1, 3));
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
- EXPECT_THAT(output_data, ElementsAre(2, 3));
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
}
}
}
@@ -1513,23 +1611,28 @@
const int num_input = TrtDimsNumElements(GetTestDims(dims_array));
ASSERT_EQ(trt_input_rank > 1 ? 6 : (data_format == "NHWC" ? 3 : 2),
num_input);
- std::vector<CType> output_data(num_input);
- test->BuildAndRun<CType>(
- {{"input", std::vector<CType>(num_input, CType(0))}}, "my_biasadd",
- &output_data);
+
+ const DataVec input_data{
+ {"input", ConstructTensor<CType>(num_input, CType(0))}};
+ DataVec output_data{{"my_biasadd", ConstructTensor<CType>(num_input)}};
+ test->BuildAndRun(input_data, &output_data);
if (trt_input_rank == 1) {
if (data_format == "NHWC") {
- EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1), CType(2), CType(3)));
} else {
- EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1), CType(2)));
}
} else {
if (data_format == "NHWC") {
- EXPECT_THAT(output_data, ElementsAre(CType(1), CType(2), CType(3),
- CType(1), CType(2), CType(3)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1), CType(2), CType(3), CType(1),
+ CType(2), CType(3)));
} else {
- EXPECT_THAT(output_data, ElementsAre(CType(1), CType(1), CType(1),
- CType(2), CType(2), CType(2)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1), CType(1), CType(1), CType(2),
+ CType(2), CType(2)));
}
}
}
@@ -1542,7 +1645,7 @@
NodeDef node_def = MakeNodeDef("my_biasadd", "BiasAdd", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Input expects tensor and weights, at my_biasadd");
+ "BiasAdd got 0 inputs but expected 2, at my_biasadd");
}
// OK. Note that kINT32 is not supported by IScaleLayer, so we don't test
@@ -1607,21 +1710,25 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({1, 1, 2}, output.tensor()->getDimensions());
- std::vector<CType> output_data(2);
- test->BuildAndRun<CType>(
- {{"input",
- /*input_data=*/swap_inputs ? operand2 : operand1}},
- "my_binary", &output_data);
+ const DataVec input_data{
+ {"input", test::AsTensor<CType>(swap_inputs ? operand2 : operand1)}};
+ DataVec output_data{{"my_binary", ConstructTensor<CType>(2)}};
+ test->BuildAndRun(input_data, &output_data);
if (node_def.op() == "Add") {
- EXPECT_THAT(output_data, ElementsAre(CType(5), CType(10.5)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(5), CType(10.5)));
} else if (node_def.op() == "Sub") {
- EXPECT_THAT(output_data, ElementsAre(CType(1), CType(4.5)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1), CType(4.5)));
} else if (node_def.op() == "Mul") {
- EXPECT_THAT(output_data, ElementsAre(CType(6), CType(22.5)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(6), CType(22.5)));
} else if (node_def.op() == "Div") {
- EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1.5), CType(2.5)));
} else if (node_def.op() == "RealDiv") {
- EXPECT_THAT(output_data, ElementsAre(CType(1.5), CType(2.5)));
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
+ ElementsAre(CType(1.5), CType(2.5)));
} else {
ASSERT_TRUE(false);
}
@@ -1656,13 +1763,14 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
- std::vector<CType> output_data(4);
- test->BuildAndRun<CType>({{"input", input}}, "my_binary", &output_data);
+ const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
+ DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
+ test->BuildAndRun(input_data, &output_data);
if (weights_dims.size() == 1) {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(22), CType(13), CType(24)));
} else {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(23), CType(24)));
}
}
@@ -1690,9 +1798,10 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 1, 2}, output.tensor()->getDimensions());
- std::vector<CType> output_data(4);
- test->BuildAndRun<CType>({{"input", input}}, "my_binary", &output_data);
- EXPECT_THAT(output_data,
+ const DataVec input_data{{"input", test::AsTensor<CType>(input)}};
+ DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
+ test->BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(11), CType(12), CType(13), CType(14)));
}
@@ -1740,17 +1849,19 @@
// Check the result of running the engine.
const int expected_num_outputs =
TrtDimsNumElements(GetTestDims(expected_output_dims));
- std::vector<CType> output_data(expected_num_outputs);
- test->BuildAndRun<CType>(
- {{"input",
- /*input_data=*/std::vector<CType>(num_inputs, CType(2))}},
- "my_binary", &output_data);
+ const DataVec input_data{
+ {"input", ConstructTensor<CType>(num_inputs, CType(2))}};
+ DataVec output_data{
+ {"my_binary", ConstructTensor<CType>(expected_num_outputs)}};
+ test->BuildAndRun(input_data, &output_data);
if (node_def.op() == "Add") {
- EXPECT_THAT(output_data, ElementsAreArray(std::vector<CType>(
- expected_num_outputs, CType(3))));
+ EXPECT_THAT(
+ GetSpanForData<CType>(output_data[0]),
+ ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(3))));
} else if (node_def.op() == "Minimum") {
- EXPECT_THAT(output_data, ElementsAreArray(std::vector<CType>(
- expected_num_outputs, CType(1))));
+ EXPECT_THAT(
+ GetSpanForData<CType>(output_data[0]),
+ ElementsAreArray(std::vector<CType>(expected_num_outputs, CType(1))));
} else {
ASSERT_TRUE(false);
}
@@ -1777,32 +1888,33 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
- std::vector<CType> output_data(4);
+ const DataVec input_data{
+ {"input1", test::AsTensor<CType>({CType(3), CType(6)})},
+ {"input2", test::AsTensor<CType>({CType(2), CType(3)})}};
+ DataVec output_data{{"my_binary", ConstructTensor<CType>(4)}};
// After broadcasting first input becomes {3, 6, 3, 6} and second input
// becomes {2, 3, 2, 3}.
- test->BuildAndRun<CType>(
- {{"input1", {CType(3), CType(6)}}, {"input2", {CType(2), CType(3)}}},
- "my_binary", &output_data);
+ test->BuildAndRun(input_data, &output_data);
if (node_def.op() == "Add") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(5), CType(8), CType(6), CType(9)));
} else if (node_def.op() == "Sub") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1), CType(4), CType(0), CType(3)));
} else if (node_def.op() == "Mul") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(6), CType(12), CType(9), CType(18)));
} else if (node_def.op() == "Div") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
} else if (node_def.op() == "RealDiv") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(1.5), CType(3), CType(1), CType(2)));
} else if (node_def.op() == "Minimum") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(2), CType(2), CType(3), CType(3)));
} else if (node_def.op() == "Maximum") {
- EXPECT_THAT(output_data,
+ EXPECT_THAT(GetSpanForData<CType>(output_data[0]),
ElementsAre(CType(3), CType(6), CType(3), CType(6)));
} else {
ASSERT_TRUE(false);
@@ -1816,7 +1928,9 @@
NodeDef node_def = MakeNodeDef("my_add", "Add", {num_inputs, "input"});
AddTestTensor("input", {1}, /*batch_size=*/1, nvinfer1::DataType::kFLOAT);
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
- "Binary ops require two inputs, at my_add");
+ StrCat("Add got ", std::to_string(num_inputs),
+ " inputs but expected 2, at my_add")
+ .c_str());
}
{
// Both inputs are weights.
@@ -1886,14 +2000,18 @@
}
TEST_F(OpConverterTest, ConvertQuantize) {
- for (const string& op :
- {"FakeQuantWithMinMaxArgs", "FakeQuantWithMinMaxVars",
- "QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3"}) {
+ const std::pair<string, int> op_with_num_inputs[4] = {
+ {"FakeQuantWithMinMaxArgs", 1},
+ {"FakeQuantWithMinMaxVars", 3},
+ {"QuantizeAndDequantizeV2", 3},
+ {"QuantizeAndDequantizeV3", 4}};
+ for (const auto& pair : op_with_num_inputs) {
// Input list is empty, should fail.
- NodeDef node_def = MakeNodeDef("my_quantize", op, {});
+ NodeDef node_def = MakeNodeDef("my_quantize", pair.first, {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- StrCat("Invalid number of inputs for ", op, ", at my_quantize")
+ StrCat(pair.first, " got 0 inputs but expected ",
+ std::to_string(pair.second), ", at my_quantize")
.c_str());
}
{
@@ -1980,9 +2098,9 @@
AddTestTensor("weights_min", {1});
AddTestTensor("weights_max", {1});
RunValidationAndConversion(
- node_def, error::INVALID_ARGUMENT,
- "Min and max inputs for QuantizeAndDequantizeV2 must be weights not "
- "tensors, at my_quantize");
+ node_def, error::UNIMPLEMENTED,
+ "The input \"input_min\" for QuantizeAndDequantizeV2 must be a constant"
+ ", at my_quantize");
}
{
// QuantizeAndDequantizeV3 ranges set via inputs, ok.
@@ -2015,7 +2133,7 @@
NodeDef node_def = MakeNodeDef("my_relu6", "Relu6", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Invalid number of inputs for Relu6, at my_relu6");
+ "Relu6 got 0 inputs but expected 1, at my_relu6");
}
// Get the NodeDef for Relu6.
@@ -2029,7 +2147,7 @@
AddTestWeights<float>("input", {1}, {1.0f});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Relu6 is only implemented for tensors, not weights, at my_relu6");
+ "The input \"input\" for Relu6 must be a tensor, at my_relu6");
}
{
// Clip tensor values and set quantization ranges, ok.
@@ -2042,10 +2160,12 @@
auto ranges = quantization_ranges();
EXPECT_EQ(ranges[output.tensor()], 6.0f);
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", {-100, -1, 0, 3, 5, 9}}}, "my_relu6",
- &output_data);
- EXPECT_THAT(output_data, ElementsAre(0, 0, 0, 3, 5, 6));
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({-100, -1, 0, 3, 5, 9})}};
+ DataVec output_data{{"my_relu6", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(0, 0, 0, 3, 5, 6));
}
}
@@ -2067,24 +2187,26 @@
ExpectTrtDimsEqualsArray({1, 20}, output.tensor()->getDimensions());
const int num_inputs = 20;
- std::vector<CType> input_data(num_inputs);
- std::vector<CType> expected_output_data(num_inputs);
+ std::vector<CType> inputs(num_inputs);
+ std::vector<CType> expected_outputs(num_inputs);
for (int i = 0; i < 20; i++) {
const CType value = CType(i - 9);
- input_data[i] = value;
- expected_output_data[i] = value * value;
+ inputs[i] = value;
+ expected_outputs[i] = value * value;
}
- std::vector<CType> output_data(num_inputs);
- test->BuildAndRun<CType>({{"input", input_data}}, "my_square", &output_data);
- ExpectArrayNear(expected_output_data, output_data);
+ const DataVec input_data{{"input", test::AsTensor<CType>(inputs)}};
+ DataVec output_data{{"my_square", ConstructTensor<CType>(num_inputs)}};
+ test->BuildAndRun(input_data, &output_data);
+ ExpectArrayNear(expected_outputs, GetSpanForData<CType>(output_data[0]));
}
TEST_F(OpConverterTest, ConvertSquare) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_square", "Square", {});
- RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
- "Square expects one input, at my_square");
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Square got 0 inputs but expected 1, at my_square");
}
{
// Input is weights, should fail.
@@ -2096,7 +2218,7 @@
AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, -5, 6});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Square is only implemented for tensors, at my_square");
+ "The input \"x\" for Square must be a tensor, at my_square");
}
// OK. Note that kINT32 is not supported by IElementWiseLayer, so we don't
@@ -2112,7 +2234,7 @@
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_act", "Relu", {});
RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
- "Relu expects one input, at my_act");
+ "Relu got 0 inputs but expected 1, at my_act");
}
{
// Input is weights, should fail.
@@ -2124,7 +2246,7 @@
AddTestWeights<int32>("input", {1, 2, 3}, {-3, -2, -1, 0, 1, 2});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Relu is only implemented for tensors, at my_act");
+ "The input \"input\" for Relu must be a tensor, at my_act");
}
// Get nodedef for activation layer.
@@ -2168,12 +2290,14 @@
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({1, 2, 3}, output.tensor()->getDimensions());
- const std::vector<float> input_data = {-100, -2, -1, 0, 1, 100};
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", input_data}}, "my_act", &output_data);
- for (int i = 0; i < input_data.size(); i++) {
- const float expected_output = get_act_output(op_name, input_data[i]);
- EXPECT_FLOAT_EQ(output_data[i], expected_output);
+ const std::vector<float> input = {-100, -2, -1, 0, 1, 100};
+ const DataVec input_data{{"input", test::AsTensor<float>(input)}};
+ DataVec output_data{{"my_act", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ for (int i = 0; i < input.size(); i++) {
+ const float expected_output = get_act_output(op_name, input[i]);
+ EXPECT_FLOAT_EQ(GetSpanForData<float>(output_data[0])[i],
+ expected_output);
}
}
}
@@ -2184,7 +2308,7 @@
NodeDef node_def = MakeNodeDef("my_expanddims", "ExpandDims", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Two inputs expected for ExpandDims, at my_expanddims");
+ "ExpandDims got 0 inputs but expected 2, at my_expanddims");
}
// Get the NodeDef for ExpandDims.
@@ -2199,18 +2323,18 @@
Reset();
AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
AddTestWeights<int32>("weights", {1}, {1});
- RunValidationAndConversion(
- node_def, error::UNIMPLEMENTED,
- "ExpandDims expects tensor for input, at my_expanddims");
+ RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+ "The input \"input\" for ExpandDims must be a "
+ "tensor, at my_expanddims");
}
{
// Axis is a tensor, should fail.
Reset();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("weights", {3});
- RunValidationAndConversion(
- node_def, error::INVALID_ARGUMENT,
- "ExpandDims expects weights for axis, at my_expanddims");
+ RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+ "The input \"axis\" for ExpandDims must be a "
+ "constant, at my_expanddims");
}
{
// Add dim at batch dimension, should fail.
@@ -2286,10 +2410,12 @@
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_expanddims",
- &output_data);
- EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6));
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
+ DataVec output_data{{"my_expanddims", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(1, 2, 3, 4, 5, 6));
}
}
@@ -2297,8 +2423,9 @@
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_squeeze", "Squeeze", {});
- RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
- "One input expected for Squeeze, at my_squeeze");
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Squeeze got 0 inputs but expected 1, at my_squeeze");
}
{
// No attrs, should fail.
@@ -2331,7 +2458,7 @@
AddTestWeights<float>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Squeeze expects tensor for input, at my_squeeze");
+ "The input \"input\" for Squeeze must be a tensor, at my_squeeze");
}
{
// Squeeze batch dim, should fail.
@@ -2406,10 +2533,12 @@
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
- std::vector<float> output_data(6);
- BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_squeeze",
- &output_data);
- EXPECT_THAT(output_data, ElementsAre(1, 2, 3, 4, 5, 6));
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
+ DataVec output_data{{"my_squeeze", ConstructTensor<float>(6)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(1, 2, 3, 4, 5, 6));
}
}
@@ -2419,7 +2548,7 @@
NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "StridedSlice expects 4 inputs, at my_strided_slice");
+ "StridedSlice got 0 inputs but expected 4, at my_strided_slice");
}
// Get nodedef for StridedSlice layer.
@@ -2450,9 +2579,9 @@
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
- RunValidationAndConversion(
- node_def, error::UNIMPLEMENTED,
- "StridedSlice is only implemented for tensors, at my_strided_slice");
+ RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+ "The input \"input\" for StridedSlice must be a "
+ "tensor, at my_strided_slice");
}
{
// Begin, end, strides are tensors, should fail.
@@ -2463,8 +2592,8 @@
AddTestTensor("end", {4});
AddTestTensor("strides", {4});
RunValidationAndConversion(
- node_def, error::INVALID_ARGUMENT,
- "StridedSlice expects weights for begin, end, and strides, at "
+ node_def, error::UNIMPLEMENTED,
+ "The input \"begin\" for StridedSlice must be a constant, at "
"my_strided_slice");
}
{
@@ -2679,10 +2808,15 @@
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output));
- std::vector<float> output_data(ok_params[i].expected_output.size());
- BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice",
- &output_data);
- EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output));
+
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({1, 2, 3, 4, 5, 6})}};
+ DataVec output_data{
+ {"my_strided_slice",
+ ConstructTensor<float>(ok_params[i].expected_output.size())}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAreArray(ok_params[i].expected_output));
}
}
@@ -2692,22 +2826,34 @@
NodeDef node_def = MakeNodeDef("my_conv2d", "Conv2D", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
- "Two inputs are expected for Conv2D, at my_conv2d");
+ "Conv2D got 0 inputs but expected 2, at my_conv2d");
}
// Get nodedef for Conv2D layer.
auto get_conv2d_nodedef =
[](std::vector<int> strides = {1, 1, 1, 1}, string padding = "SAME",
- string data_format = "NCHW",
- std::vector<int> dilations = {1, 1, 1, 1}) -> NodeDef {
+ string data_format = "NCHW", std::vector<int> dilations = {1, 1, 1, 1},
+ bool is_conv2d_backprop_input = false) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
auto filter = ops::Placeholder(s.WithOpName("weights"), DT_FLOAT);
- ops::Conv2D::Attrs attrs =
- ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
- auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter, strides,
- padding, attrs);
- return conv2d.operation.node()->def();
+ if (is_conv2d_backprop_input) {
+ auto input_sizes =
+ ops::Placeholder(s.WithOpName("input_sizes"), DT_INT32);
+ ops::Conv2DBackpropInput::Attrs attrs = ops::Conv2DBackpropInput::Attrs()
+ .DataFormat(data_format)
+ .Dilations(dilations);
+ auto conv2d =
+ ops::Conv2DBackpropInput(s.WithOpName("my_conv2d"), input_sizes,
+ filter, input, strides, padding, attrs);
+ return conv2d.operation.node()->def();
+ } else {
+ ops::Conv2D::Attrs attrs =
+ ops::Conv2D::Attrs().DataFormat(data_format).Dilations(dilations);
+ auto conv2d = ops::Conv2D(s.WithOpName("my_conv2d"), input, filter,
+ strides, padding, attrs);
+ return conv2d.operation.node()->def();
+ }
};
{
@@ -2718,7 +2864,7 @@
AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Conv2D is only implemented for tensors, not weights, at my_conv2d");
+ "The input \"input\" for Conv2D must be a tensor, at my_conv2d");
}
{
// Filter is tensor, should fail.
@@ -2728,7 +2874,7 @@
AddTestTensor("weights", {3, 3, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
- "Kernel for Conv2D must be constant weights, at my_conv2d");
+ "The input \"filter\" for Conv2D must be a constant, at my_conv2d");
}
{
// Filter is not 4D, should fail.
@@ -2774,6 +2920,19 @@
"dimensions, at my_conv2d");
}
{
+ // Dilation + Conv2DBackpropInput, should fail.
+ Reset();
+ NodeDef node_def =
+ get_conv2d_nodedef({1, 1, 1, 1}, "SAME", "NHWC", {1, 1, 2, 1}, true);
+ AddTestTensor("input", {2, 3, 1});
+ AddTestWeights<float>("weights", {3, 3, 1, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ AddTestWeights<int>("input_sizes", {4}, {1, 2, 3, 1});
+ RunValidationAndConversion(node_def, error::UNIMPLEMENTED,
+ "Dilation with Conv2DBackpropInput "
+ "(conv2d_transpose) is not supported, "
+ "at my_conv2d");
+ }
+ {
// Strides is not 4D, should fail.
Reset();
NodeDef node_def =
@@ -2803,6 +2962,7 @@
const std::vector<float>& filter,
const std::vector<int>& strides, const string& padding,
const string& data_format, const std::vector<int>& dilations,
+ bool is_conv2d_backprop_input,
const std::vector<int>& expected_output_dims,
const std::vector<float>& expected_output)
: input_dims(input_dims),
@@ -2813,6 +2973,7 @@
padding(padding),
data_format(data_format),
dilations(dilations),
+ is_conv2d_backprop_input(is_conv2d_backprop_input),
expected_output_dims(expected_output_dims),
expected_output(expected_output) {}
@@ -2824,12 +2985,13 @@
string padding;
string data_format;
std::vector<int> dilations;
+ bool is_conv2d_backprop_input;
std::vector<int> expected_output_dims;
std::vector<float> expected_output;
};
// Ok.
- const int kConv2DOKCases = 6;
+ const int kConv2DOKCases = 7;
TestParams ok_params[kConv2DOKCases] = {
// Basic
TestParams{/*input_dims=*/{1, 2, 3},
@@ -2840,6 +3002,7 @@
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{1, 2, 2},
/*expected_output=*/{1, 1, 0, 1}},
// SAME padding (Asymmetric)
@@ -2851,6 +3014,7 @@
/*padding=*/"SAME",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{1, 2, 3},
/*expected_output=*/{1, 1, -2, 0, 1, -4}},
// SAME padding (Symmetric)
@@ -2862,6 +3026,7 @@
/*padding=*/"SAME",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{1, 2, 3},
/*expected_output=*/{1, 2, -1, 3, 1, -3}},
// NHWC
@@ -2873,6 +3038,7 @@
/*padding=*/"VALID",
/*data_format=*/"NHWC",
/*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{2, 2, 1},
/*expected_output=*/{1, 1, 0, 1}},
// Dilated
@@ -2884,6 +3050,7 @@
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 2},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{1, 2, 1},
/*expected_output=*/{2, 1}},
// Strided
@@ -2895,28 +3062,104 @@
/*padding=*/"VALID",
/*data_format=*/"NCHW",
/*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/false,
/*expected_output_dims=*/{1, 2, 2},
/*expected_output=*/{1, 0, 1, 3}},
+ // Transpose Strided
+ TestParams{/*input_dims=*/{1, 2, 2},
+ /*input=*/{0, 1, 2, 3},
+ /*filter_dims=*/{1, 2, 1, 1},
+ /*filter=*/{-1, 1},
+ /*strides=*/{1, 1, 1, 2},
+ /*padding=*/"SAME",
+ /*data_format=*/"NCHW",
+ /*dilations=*/{1, 1, 1, 1},
+ /*is_conv2d_backprop_input=*/true,
+ /*expected_output_dims=*/{1, 2, 4},
+ /*expected_output=*/{0, 0, -1, 1, -2, 2, -3, 3}},
};
for (int i = 0; i < kConv2DOKCases; i++) {
Reset();
- NodeDef node_def =
- get_conv2d_nodedef(ok_params[i].strides, ok_params[i].padding,
- ok_params[i].data_format, ok_params[i].dilations);
+ NodeDef node_def = get_conv2d_nodedef(
+ ok_params[i].strides, ok_params[i].padding, ok_params[i].data_format,
+ ok_params[i].dilations, ok_params[i].is_conv2d_backprop_input);
AddTestTensor("input", ok_params[i].input_dims);
AddTestWeights<float>("weights", ok_params[i].filter_dims,
ok_params[i].filter);
+ if (ok_params[i].is_conv2d_backprop_input) {
+ AddTestWeights<float>(
+ "input_sizes",
+ {static_cast<int>(ok_params[i].expected_output.size())},
+ ok_params[i].expected_output);
+ }
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_conv2d", &output));
EXPECT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray(ok_params[i].expected_output_dims,
output.tensor()->getDimensions());
- std::vector<float> output_data(ok_params[i].expected_output.size());
- BuildAndRun<float>({{"input", ok_params[i].input}}, "my_conv2d",
- &output_data);
- EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output));
+
+ const DataVec input_data{
+ {"input", test::AsTensor<float>(ok_params[i].input)}};
+ DataVec output_data{
+ {"my_conv2d",
+ ConstructTensor<float>(ok_params[i].expected_output.size())}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAreArray(ok_params[i].expected_output));
+ }
+}
+
+TEST_F(OpConverterTest, ConvertTopK) {
+ {
+ // Input list is empty, should fail.
+ NodeDef node_def = MakeNodeDef("my_topk", "TopKV2", {});
+ RunValidationAndConversion(node_def, error::INVALID_ARGUMENT,
+ "Input expects tensor and weights, at my_topk");
+ }
+
+ for (const auto dtype : {DT_FLOAT, DT_INT32}) {
+ // Get the NodeDef for TopKV2.
+ Scope s = Scope::NewRootScope();
+ auto input = ops::Placeholder(s.WithOpName("input"), dtype);
+ auto weights = ops::Placeholder(s.WithOpName("weights"), DT_INT32);
+ auto topk = ops::TopK(s.WithOpName("my_topk"), input, weights);
+ const NodeDef& node_def = topk.operation.node()->def();
+ {
+ // K is a tensor, should fail.
+ Reset();
+ AddTestTensor("input", {1, 2, 3}, /*batch_size=*/1,
+ /*trt_dtype=*/TfDataTypeToTrt(dtype));
+ AddTestTensor("weights", {2});
+ RunValidationAndConversion(
+ node_def, error::INVALID_ARGUMENT,
+ "Input expects tensor and weights, at my_topk");
+ }
+ {
+ // Ok.
+ Reset();
+ AddTestTensor("input", {1, 2, 5});
+ AddTestWeights<int32>("weights", {1}, {2});
+ RunValidationAndConversion(node_def);
+ TRT_TensorOrWeights outputs[2];
+ TF_EXPECT_OK(GetTensorOrWeights("my_topk", &outputs[0]));
+ TF_EXPECT_OK(GetTensorOrWeights("my_topk:1", &outputs[1]));
+ for (auto& output : outputs) {
+ EXPECT_TRUE(output.is_tensor());
+ ExpectTrtDimsEqualsArray({1, 2, 2}, output.tensor()->getDimensions());
+ }
+
+ const DataVec input_data{
+ {"input", test::AsTensor<float>({-9, 3, 5, 1, 6, -5, 7, 1, 0, -1})}};
+ DataVec output_data{{"my_topk", ConstructTensor<float>(4)},
+ {"my_topk:1", ConstructTensor<int32>(4)}};
+ BuildAndRun(input_data, &output_data);
+ EXPECT_THAT(GetSpanForData<float>(output_data[0]),
+ ElementsAre(6, 5, 7, 1));
+ EXPECT_THAT(GetSpanForData<int32>(output_data[1]),
+ ElementsAre(4, 2, 1, 2));
+ }
}
}
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
index ebf8df1..f36aa55 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.cc
@@ -66,7 +66,7 @@
max_workspace_size_bytes_ = params.at("max_workspace_size_bytes").i();
}
if (params.count("precision_mode")) {
- TF_RETURN_IF_ERROR(GetPrecisionMode(
+ TF_RETURN_IF_ERROR(TrtPrecisionModeFromName(
Uppercase(params.at("precision_mode").s()), &precision_mode_));
}
if (params.count("use_calibration")) {
@@ -227,7 +227,7 @@
TF_RETURN_IF_ERROR(static_graph_properties.InferStatically(true));
tensorflow::tensorrt::convert::ConversionParams cp;
- if (use_calibration_ && precision_mode_ != INT8MODE) {
+ if (use_calibration_ && precision_mode_ != TrtPrecisionMode::INT8) {
VLOG(1) << "Calibration with FP32 or FP16 is not implemented. "
<< "Falling back to use_calibration = False."
<< "Note that the default value of use_calibration is True.";
diff --git a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
index bd6c6db..b2aed2a 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/trt_optimization_pass.h
@@ -18,6 +18,7 @@
#include <string>
+#include "tensorflow/compiler/tf2tensorrt/convert/utils.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/platform/logging.h"
@@ -34,7 +35,7 @@
TRTOptimizationPass(const string& name = "TRTOptimizationPass")
: name_(name),
minimum_segment_size_(3),
- precision_mode_(0),
+ precision_mode_(TrtPrecisionMode::FP32),
maximum_batch_size_(-1),
is_dynamic_op_(false),
max_cached_batches_(1),
@@ -62,7 +63,7 @@
private:
const string name_;
int minimum_segment_size_;
- int precision_mode_;
+ TrtPrecisionMode precision_mode_;
int maximum_batch_size_;
bool is_dynamic_op_;
std::vector<int> batches_;
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
index 62a0f62..0ca3a5a 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc
@@ -34,33 +34,32 @@
#endif
}
-Status GetPrecisionModeName(const int precision_mode, string* name) {
- switch (precision_mode) {
- case FP32MODE:
+Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name) {
+ switch (mode) {
+ case TrtPrecisionMode::FP32:
*name = "FP32";
break;
- case FP16MODE:
+ case TrtPrecisionMode::FP16:
*name = "FP16";
break;
- case INT8MODE:
+ case TrtPrecisionMode::INT8:
*name = "INT8";
break;
default:
- return tensorflow::errors::OutOfRange("Unknown precision mode");
+ return errors::OutOfRange("Unknown precision mode");
}
return Status::OK();
}
-Status GetPrecisionMode(const string& name, int* precision_mode) {
+Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode) {
if (name == "FP32") {
- *precision_mode = FP32MODE;
+ *mode = TrtPrecisionMode::FP32;
} else if (name == "FP16") {
- *precision_mode = FP16MODE;
+ *mode = TrtPrecisionMode::FP16;
} else if (name == "INT8") {
- *precision_mode = INT8MODE;
+ *mode = TrtPrecisionMode::INT8;
} else {
- return tensorflow::errors::InvalidArgument("Invalid precision mode name: ",
- name);
+ return errors::InvalidArgument("Invalid precision mode name: ", name);
}
return Status::OK();
}
diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h
index 9f9ee59..0aa602d 100644
--- a/tensorflow/compiler/tf2tensorrt/convert/utils.h
+++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h
@@ -35,14 +35,11 @@
bool IsGoogleTensorRTEnabled();
-// TODO(aaroey): use an enum instead.
-const int FP32MODE = 0;
-const int FP16MODE = 1;
-const int INT8MODE = 2;
+enum class TrtPrecisionMode { FP32, FP16, INT8 };
-Status GetPrecisionModeName(const int precision_mode, string* name);
+Status TrtPrecisionModeToName(TrtPrecisionMode mode, string* name);
-Status GetPrecisionMode(const string& name, int* precision_mode);
+Status TrtPrecisionModeFromName(const string& name, TrtPrecisionMode* mode);
} // namespace tensorrt
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
index ec9dc08..bc5335e 100644
--- a/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/compiler/tf2tensorrt/kernels/trt_engine_op.cc
@@ -113,7 +113,7 @@
GraphDef segment_graph_;
// Engine Precision mode.
- int precision_mode_;
+ TrtPrecisionMode precision_mode_;
// Whether engine is constructed during the conversion or needs to be
// constructed from protobuf segment.
@@ -210,11 +210,13 @@
context->GetAttr("calibration_data", &calibration_data));
OP_REQUIRES_OK(context,
context->GetAttr("segment_funcdef_name", &funcdef_name_));
- OP_REQUIRES_OK(context, GetPrecisionMode(precision_string, &precision_mode_));
+ OP_REQUIRES_OK(context,
+ TrtPrecisionModeFromName(precision_string, &precision_mode_));
OP_REQUIRES_OK(context,
context->GetAttr("use_calibration", &use_calibration_));
- calibration_mode_ = (use_calibration_ && precision_mode_ == INT8MODE &&
- calibration_data.size() == 0);
+ calibration_mode_ =
+ (use_calibration_ && precision_mode_ == TrtPrecisionMode::INT8 &&
+ calibration_data.size() == 0);
if (calibration_data.size()) {
calibrator_.reset(new TRTInt8Calibrator(calibration_data));
calibration_data.resize(0);
@@ -712,9 +714,10 @@
// TODO(aaroey): maybe setting the max batch size using the python
// calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
- *segment_graph, INT8MODE, cres->calibrator_->getBatchSize(),
- workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
- cres->calibrator_.get(), &cres->engine_,
+ *segment_graph, TrtPrecisionMode::INT8,
+ cres->calibrator_->getBatchSize(), workspace_size_bytes, shapes,
+ &cres->logger_, cres->allocator_.get(), cres->calibrator_.get(),
+ &cres->engine_,
/*use_calibration=*/true,
/*convert_successfully=*/nullptr);
if (!s.ok()) {
diff --git a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc
index f454f55..6bc842e 100644
--- a/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc
+++ b/tensorflow/compiler/tf2tensorrt/utils/trt_logger.cc
@@ -26,7 +26,7 @@
void Logger::log(Severity severity, const char* msg) {
// Suppress info-level messages
switch (severity) {
-#if NV_TENSORRT_MAJOR >= 5 && NV_TENSORRT_MINOR >= 1
+#if NV_TENSORRT_MAJOR > 5 || (NV_TENSORRT_MAJOR == 5 && NV_TENSORRT_MINOR >= 1)
case Severity::kVERBOSE:
#endif
case Severity::kINFO: { // Mark TRT info messages as debug!
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 14bd9d4..02de951 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -1,6 +1,6 @@
licenses(["notice"]) # Apache 2.0
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test", "tf_cuda_cc_test")
package_group(
name = "internal",
@@ -679,3 +679,25 @@
"@com_google_absl//absl/strings",
],
)
+
+tf_cuda_cc_test(
+ name = "fused_batchnorm_reserve_space_test",
+ size = "medium",
+ srcs = ["fused_batchnorm_reserve_space_test.cc"],
+ deps = [
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/compiler/jit",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/algorithm:container",
+ ],
+)
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.cc b/tensorflow/compiler/tf2xla/functionalize_cond.cc
index e780c93..c8341a2 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.cc
@@ -34,6 +34,8 @@
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/strcat.h"
using xla::StatusOr;
@@ -41,6 +43,26 @@
namespace tensorflow {
namespace functionalize_cond {
+bool AncestorNode::operator<(const AncestorNode& other) const {
+ return (output_tensor.node->id() < other.output_tensor.node->id()) ||
+ (output_tensor.node->id() == other.output_tensor.node->id() &&
+ output_tensor.index < other.output_tensor.index) ||
+ (output_tensor.node->id() == other.output_tensor.node->id() &&
+ output_tensor.index == other.output_tensor.index &&
+ type < other.type);
+}
+
+bool AncestorNode::operator==(const AncestorNode& other) const {
+ return output_tensor.node->id() == other.output_tensor.node->id() &&
+ output_tensor.index == other.output_tensor.index && type == other.type;
+}
+
+size_t AncestorNode::Hash::operator()(const AncestorNode& ancestor) const {
+ size_t h = std::hash<int>()(ancestor.output_tensor.node->id());
+ h = Hash64Combine(h, std::hash<int>()(ancestor.output_tensor.index));
+ return Hash64Combine(h, std::hash<int>()(static_cast<int>(ancestor.type)));
+}
+
// TODO(jpienaar): Move to OutputTensor.
string DebugString(const OutputTensor& tensor) {
return absl::StrCat(tensor.node->name(), ":", tensor.index);
@@ -145,10 +167,10 @@
if (map.empty()) return 0;
// Compute hash of the front element.
auto it = map.begin();
- size_t h = hash<Node*>()(*it);
+ size_t h = AncestorNode::Hash()(*it);
for (++it; it != map.end(); ++it) {
// Combine the has with the different elements in the map.
- h = Hash64Combine(h, hash<Node*>()(*it));
+ h = Hash64Combine(h, AncestorNode::Hash()(*it));
}
return h;
}
@@ -229,7 +251,17 @@
}
string StateMap::AncestorStateToString(const Node* node) const {
- if (auto id = LookupAncestorId(node)) return NodesToString(*id);
+ if (auto id = LookupAncestorId(node)) {
+ return absl::StrCat(
+ "{",
+ absl::StrJoin(*id, ",",
+ [](string* output, const AncestorNode& ancestor) {
+ absl::StrAppend(output,
+ ancestor.output_tensor.node->name(),
+ ":", ancestor.output_tensor.index);
+ }),
+ "}");
+ }
return "{}";
}
@@ -967,6 +999,10 @@
VLOG(4) << "Joining (for merge) " << DebugString(src) << " and "
<< DebugString(dst);
if (state_map_.IsEmpty(dst)) return src;
+ if (state_map_.IsEmpty(src)) {
+ return errors::Internal("Merge node ", merge->name(),
+ " has input that's not in any CondContext.");
+ }
if (state_map_.IsDead(src)) return src;
if (state_map_.IsDead(dst)) return dst;
@@ -1201,8 +1237,17 @@
if (other_id != id && other_id != nullptr) {
state.insert(other_id->begin(), other_id->end());
}
- if (IsSwitch(src) || IsMerge(src)) {
- state.insert(src);
+ if (IsMerge(src)) {
+ state.insert({{src, 0}, AncestorNode::AncestorNodeType::kMerge});
+ } else if (IsSwitch(src)) {
+ OutputTensor pred;
+ // For dead switch nodes, GetSwitchPredicate() will fail, and we use
+ // the switch node directly as ancestor.
+ if (GetSwitchPredicate(*src, &pred).ok()) {
+ state.insert({pred, AncestorNode::AncestorNodeType::kPred});
+ } else {
+ state.insert({{src, 0}, AncestorNode::AncestorNodeType::kSwitch});
+ }
}
return state_map_.GetAncestorId(state);
};
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond.h b/tensorflow/compiler/tf2xla/functionalize_cond.h
index 9a610a5..d85800f 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond.h
+++ b/tensorflow/compiler/tf2xla/functionalize_cond.h
@@ -43,6 +43,33 @@
kNeither = 3,
};
+// When we keep track of which switch/merge node's feed into a node, we record
+// 1) predicate for non-dead switch node,
+// 2) the switch node itself for dead switch node,
+// 3) the merge node itself for merge node.
+// Case 1) is an optimization. With this optimization, if there are nodes from
+// different switch nodes but those switch nodes have the same predicate, the
+// nodes will still have same AncestorState, and they will be clustered into a
+// single "If".
+struct AncestorNode {
+ enum class AncestorNodeType {
+ kPred = 0,
+ kSwitch = 1,
+ kMerge = 2,
+ };
+
+ OutputTensor output_tensor;
+ AncestorNodeType type;
+
+ // Compare two AncestorNodes by (node id, index, type).
+ bool operator<(const AncestorNode& other) const;
+ bool operator==(const AncestorNode& other) const;
+
+ struct Hash {
+ size_t operator()(const AncestorNode&) const;
+ };
+};
+
// StateMap is responsible for mapping from each graph Node to
// * a CondState, where each CondState is a map from predicate to branch (i,e.,
// what predicates have to hold or not hold).
@@ -68,7 +95,7 @@
using CondId = const CondState*;
// Keep track of which switch/merge node's feed into a node's values.
- using AncestorState = std::set<Node*>;
+ using AncestorState = std::set<AncestorNode>;
// Every unique ID is mapped to a AncestorState.
using AncestorId = const AncestorState*;
diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
index b0aabd6..05fa1ee 100644
--- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc
@@ -101,6 +101,17 @@
TF_EXPECT_OK(t.status());
}
+TEST_F(FunctionalizeCondTest, JoinCondStatesMergeWithInputNotInCondContext) {
+ Tensor val_tensor(DT_INT32, TensorShape());
+ val_tensor.flat<int>().setZero();
+ Node* val = test::graph::Constant(graph_.get(), val_tensor, "val");
+ Node* m = test::graph::Merge(graph_.get(), val, val);
+
+ StateMap::CondState cond_state;
+ auto joined_or = JoinCondStatesMerge(m, /*src=*/nullptr, &cond_state);
+ EXPECT_FALSE(joined_or.ok());
+}
+
} // namespace
} // namespace functionalize_cond
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc
new file mode 100644
index 0000000..4535ece
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/fused_batchnorm_reserve_space_test.cc
@@ -0,0 +1,130 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace {
+Status GetTestDevice(Session* session, string* test_device) {
+ std::vector<DeviceAttributes> devices;
+ TF_RETURN_IF_ERROR(session->ListDevices(&devices));
+
+ bool found_cpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
+ return device.device_type() == "CPU";
+ });
+
+ bool found_gpu = absl::c_any_of(devices, [&](const DeviceAttributes& device) {
+ return device.device_type() == "GPU";
+ });
+
+ if (!found_gpu && !found_cpu) {
+ return errors::Internal("Expected at least one CPU or GPU!");
+ }
+
+ *test_device = found_gpu ? "GPU" : "CPU";
+ VLOG(2) << "Using test device " << *test_device;
+ return Status::OK();
+}
+
+void FillZeros(Tensor* tensor) {
+ auto flat = tensor->flat<float>();
+ for (int i = 0; i < flat.size(); i++) {
+ flat.data()[i] = 0.0f;
+ }
+}
+
+// This tests check that the implementation outputs from FusedBatchnorm
+// training, reserve_space_{1|2}, are what we assume them to be in the TF/XLA
+// lowering.
+//
+// If this test starts failing then it doesn't indicate that TF/cudnn have
+// violated their contract, but it indicates that we need to update the TF/XLA
+// lowering for FusedBatchnorm training to match the new implementation defined
+// behavior.
+TEST(FusedBatchnormReserveSpaceTest, Test) {
+ using ::tensorflow::ops::Const;
+ using ::tensorflow::ops::FusedBatchNorm;
+
+ std::unique_ptr<tensorflow::Session> session(
+ tensorflow::NewSession(tensorflow::SessionOptions{}));
+
+ string test_device;
+ TF_ASSERT_OK(GetTestDevice(session.get(), &test_device));
+
+ Scope root = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(root.WithOpName("input"), DT_FLOAT);
+
+ Tensor scale_data(DT_FLOAT, TensorShape({10}));
+ FillZeros(&scale_data);
+ Output scale =
+ Const(root.WithOpName("scale"), Input::Initializer(scale_data));
+
+ Tensor offset_data(DT_FLOAT, TensorShape({10}));
+ FillZeros(&offset_data);
+ Output offset =
+ Const(root.WithOpName("offset"), Input::Initializer(offset_data));
+
+ Tensor mean_data(DT_FLOAT, TensorShape({0}));
+ Output mean = Const(root.WithOpName("offset"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({0}));
+ Output variance =
+ Const(root.WithOpName("variance"), Input::Initializer(variance_data));
+
+ string tf_device = absl::StrCat("/device:", test_device, ":0");
+ string xla_device = absl::StrCat("/device:XLA_", test_device, ":0");
+
+ FusedBatchNorm fused_batch_norm_tf(
+ root.WithOpName("fused_batch_norm_tf").WithDevice(tf_device), input,
+ scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));
+ FusedBatchNorm fused_batch_norm_xla(
+ root.WithOpName("fused_batch_norm_xla").WithDevice(xla_device), input,
+ scale, offset, mean, variance, FusedBatchNorm::Attrs{}.IsTraining(true));
+
+ tensorflow::GraphDef graph;
+ TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+ TF_ASSERT_OK(session->Create(graph));
+
+ Tensor input_data(DT_FLOAT, TensorShape({10, 10, 10, 10}));
+ auto flat_input = input_data.flat<float>();
+ for (int i = 0; i < flat_input.size(); i++) {
+ flat_input.data()[i] = (i - 5) / 1000.0f;
+ }
+
+ std::vector<Tensor> results;
+ TF_ASSERT_OK(session->Run({{"input", input_data}},
+ {fused_batch_norm_tf.reserve_space_1.name(),
+ fused_batch_norm_xla.reserve_space_1.name(),
+ fused_batch_norm_tf.reserve_space_2.name(),
+ fused_batch_norm_xla.reserve_space_2.name()},
+ {}, &results));
+
+ test::ExpectClose(results[0], results[1], /*atol=*/1e-4);
+ test::ExpectClose(results[2], results[3], /*atol=*/1e-4);
+}
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index bfe668f..69353fe 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -147,6 +147,7 @@
"//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:control_flow_ops_op_lib",
+ "//tensorflow/core:data_flow_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:image_ops_op_lib",
@@ -155,10 +156,14 @@
"//tensorflow/core:list_ops_op_lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:random_ops_op_lib",
"//tensorflow/core:resource_variable_ops_op_lib",
+ "//tensorflow/core:sendrecv_ops_op_lib",
+ "//tensorflow/core:sparse_ops_op_lib",
"//tensorflow/core:spectral_ops_op_lib",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:stateless_random_ops_op_lib",
"//tensorflow/core:training_ops_op_lib",
"//tensorflow/core/kernels:bounds_check",
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 795ea09..5554d7a 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -53,7 +53,11 @@
const XlaExpression& arg = ctx->xla_context()->args()[index_];
OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
errors::InvalidArgument("Invalid/missing argument expression"));
- ctx->SetOutputExpression(0, arg);
+ if (ctx->expected_output_dtype(0) == DT_VARIANT) {
+ ctx->SetTensorListOutput(0, arg.handle());
+ } else {
+ ctx->SetOutputExpression(0, arg);
+ }
}
private:
@@ -63,6 +67,8 @@
TF_DISALLOW_COPY_AND_ASSIGN(XlaArgOp);
};
-REGISTER_XLA_OP(Name("_Arg").AllowResourceTypes().CompilationOnly(), XlaArgOp);
+REGISTER_XLA_OP(
+ Name("_Arg").AllowResourceTypes().AllowVariantTypes().CompilationOnly(),
+ XlaArgOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 83c6a65..f1d78c8 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -18,6 +18,7 @@
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/util/tensor_format.h"
@@ -35,6 +36,7 @@
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
+ is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -72,7 +74,18 @@
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
ctx->SetOutput(3, xla::GetTupleElement(output, 1));
- ctx->SetOutput(4, xla::GetTupleElement(output, 2));
+ if (is_on_gpu_) {
+ // The last two outputs from the FusedBatchNorm training TensorFlow GPU
+ // op are implementation defined. For now we rely on the in-practice
+ // behavior of the op:
+ // output 3 is the mean
+ // output 4 is rsqrt(variance + epsilon)
+ xla::XlaOp variance = xla::GetTupleElement(output, 2);
+ ctx->SetOutput(4, xla::Rsqrt(xla::Add(
+ variance, xla::ScalarLike(variance, epsilon_))));
+ } else {
+ ctx->SetOutput(4, xla::GetTupleElement(output, 2));
+ }
} else {
xla::XlaOp output = xla::BatchNormInference(
input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
@@ -90,6 +103,7 @@
float epsilon_;
TensorFormat data_format_;
bool is_training_;
+ bool is_on_gpu_;
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
@@ -105,6 +119,7 @@
OP_REQUIRES(
ctx, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format: ", data_format_str));
+ is_on_gpu_ = ctx->device_type().type_string() == DEVICE_GPU_XLA_JIT;
}
void Compile(XlaOpKernelContext* ctx) override {
@@ -131,6 +146,22 @@
xla::XlaOp scale_backprop;
xla::XlaOp offset_backprop;
if (is_training_) {
+ if (is_on_gpu_) {
+ // The last two inputs to the FusedBatchNormGrad training TensorFlow GPU
+ // op are implementation defined. For now we rely on the in-practice
+ // behavior of the op: input 3 is the mean input 4 is rsqrt(variance +
+ // epsilon)
+ //
+ // The XLA op expects:
+ // input 3 is the mean
+ // input 4 is the variance
+ //
+ // so we adjust input 4 here.
+ xla::XlaOp one = xla::ScalarLike(var, 1.0f);
+ xla::XlaOp epsilon = xla::ScalarLike(var, epsilon_);
+ var = xla::Sub(one / (var * var), epsilon);
+ }
+
xla::XlaOp output =
xla::BatchNormGrad(activations, scale, mean, var, grad_backprop,
epsilon_, feature_index);
@@ -187,6 +218,7 @@
TensorFormat data_format_;
float epsilon_;
bool is_training_;
+ bool is_on_gpu_;
};
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
index 5f99b24..e8b270c 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc
@@ -203,7 +203,8 @@
StringPiece label, int num_spatial_dims, const xla::Shape& input_shape,
const xla::Shape& filter_shape, const xla::Shape& out_backprop_shape,
absl::Span<const int32> dilations, const std::vector<int32>& strides,
- Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims) {
+ Padding padding, TensorFormat data_format, ConvBackpropDimensions* dims,
+ absl::Span<const int64> explicit_paddings) {
TensorShape input_tensor_shape, filter_tensor_shape,
out_backprop_tensor_shape;
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(input_shape, &input_tensor_shape));
@@ -212,8 +213,8 @@
XLAShapeToTensorShape(out_backprop_shape, &out_backprop_tensor_shape));
return ConvBackpropComputeDimensionsV2(
label, num_spatial_dims, input_tensor_shape, filter_tensor_shape,
- out_backprop_tensor_shape, dilations, strides, padding,
- /*explicit_paddings=*/{}, data_format, dims);
+ out_backprop_tensor_shape, dilations, strides, padding, explicit_paddings,
+ data_format, dims);
}
} // anonymous namespace
@@ -227,10 +228,9 @@
TF_RETURN_IF_ERROR(ctx->GetAttr("dilations", &attrs.dilations));
TF_RETURN_IF_ERROR(ctx->GetAttr("strides", &attrs.strides));
TF_RETURN_IF_ERROR(ctx->GetAttr("padding", &attrs.padding));
- // TODO(reedwm): Support explicit padding.
if (attrs.padding == EXPLICIT) {
- return errors::Unimplemented(
- "XLA does not yet support Conv2D with explicit padding.");
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("explicit_paddings", &attrs.explicit_paddings));
}
string data_format;
@@ -303,6 +303,11 @@
window_strides[i] = attrs.strides.at(dim);
rhs_dilation[i] = attrs.dilations.at(dim);
+ if (attrs.padding == EXPLICIT) {
+ padding[i] = {attrs.explicit_paddings.at(dim * 2),
+ attrs.explicit_paddings.at(dim * 2 + 1)};
+ }
+
int64 unused_output_size;
TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerboseV2(
input_shape.dimensions(dim), filter_shape.dimensions(i),
@@ -337,7 +342,7 @@
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
type_string, attrs.num_spatial_dims, input_shape, expanded_filter_shape,
out_backprop_shape, attrs.dilations, attrs.strides, attrs.padding,
- attrs.data_format, &dims));
+ attrs.data_format, &dims, attrs.explicit_paddings));
// The input gradients are computed by a convolution of the output
// gradients and the filter, with some appropriate padding. See the
@@ -420,7 +425,7 @@
TF_RETURN_IF_ERROR(ConvBackpropComputeDimensionsV2XlaShapes(
type_string, attrs.num_spatial_dims, activations_shape,
expanded_filter_shape, out_backprop_shape, attrs.dilations, attrs.strides,
- attrs.padding, attrs.data_format, &dims));
+ attrs.padding, attrs.data_format, &dims, attrs.explicit_paddings));
// The activations (inputs) form the LHS of the convolution.
// Activations have shape: [batch, in_rows, in_cols, ..., in_depth]
@@ -469,6 +474,8 @@
int64 dim = GetTensorSpatialDimIndex(num_dims, attrs.data_format, i);
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
+ rhs_dilation[i] = dims.spatial_dims[i].stride;
+ window_strides[i] = attrs.dilations[dim];
// We will also need to pad the input with zeros such that after the
// convolution, we get the right size for the filter.
@@ -495,6 +502,8 @@
// We apply negative padding in this case.
const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size;
+ // + For the EXPLICIT padding, we pad the top/left side with the explicit
+ // padding and pad the bottom/right side with the remaining space.
// + For the VALID padding, we don't pad anything on the top/left side
// and pad the bottom/right side with the remaining space.
// + For the SAME padding, we pad top/left side the same as bottom/right
@@ -503,12 +512,12 @@
// In addition, if the padded input size is smaller than the input size,
// we need to ignore some training elements of the input. We do this by
// applying negative padding on the right/bottom.
- const int64 pad_before =
- attrs.padding == Padding::SAME ? std::max<int64>(pad_total / 2, 0) : 0;
-
+ const int64 pad_before = attrs.padding == Padding::EXPLICIT
+ ? attrs.explicit_paddings[2 * dim]
+ : attrs.padding == Padding::SAME
+ ? std::max<int64>(pad_total / 2, 0)
+ : 0;
padding[i] = {pad_before, pad_total - pad_before};
- rhs_dilation[i] = dims.spatial_dims[i].stride;
- window_strides[i] = attrs.dilations[dim];
}
// Besides padding the input, we will also expand output_rows to
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
index 6e1b70a..d893eca 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
+++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.h
@@ -47,6 +47,7 @@
std::vector<int32> dilations;
std::vector<int32> strides;
Padding padding;
+ std::vector<int64> explicit_paddings;
TensorFormat data_format;
};
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index 19dd38c..8b27e8e 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -38,9 +38,13 @@
// XLA_* devices also register a "real" Identity operator so we suppress the
// dummy operator using CompilationOnly().
-REGISTER_XLA_OP(Name("Identity").AllowResourceTypes().CompilationOnly(),
- IdentityOp);
-REGISTER_XLA_OP(Name("IdentityN").AllowResourceTypes().CompilationOnly(),
+REGISTER_XLA_OP(
+ Name("Identity").AllowResourceTypes().AllowVariantTypes().CompilationOnly(),
+ IdentityOp);
+REGISTER_XLA_OP(Name("IdentityN")
+ .AllowResourceTypes()
+ .AllowVariantTypes()
+ .CompilationOnly(),
IdentityOp);
REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
index 90c0ebe..b322495 100644
--- a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc
@@ -31,7 +31,8 @@
void Compile(XlaOpKernelContext* ctx) override {
auto result = xla::TriangularSolve(
ctx->Input(0), ctx->Input(1), /*left_side=*/true,
- /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_);
+ /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_,
+ /*unit_diagonal=*/false);
ctx->SetOutput(0, result);
}
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index e4046c7..1f41703 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -37,10 +37,14 @@
void Compile(XlaOpKernelContext* ctx) override {
const Tensor& input = ctx->op_kernel_context()->input(0);
- OP_REQUIRES(ctx, input.dtype() == dtype_,
- errors::InvalidArgument(
- "Type mismatch: actual ", DataTypeString(input.dtype()),
- " vs. expect ", DataTypeString(dtype_)));
+ // DT_VARIANT types represent Tensor Lists and are wrapped in a DT_UINT8
+ // tensor so we skip the check here.
+ if (dtype_ != DT_VARIANT) {
+ OP_REQUIRES(ctx, input.dtype() == dtype_,
+ errors::InvalidArgument(
+ "Type mismatch: actual ", DataTypeString(input.dtype()),
+ " vs. expect ", DataTypeString(dtype_)));
+ }
auto frame = ctx->call_frame();
if (frame) {
// If 'frame' is non-null, this is an inner function call inside a JIT
@@ -59,8 +63,9 @@
TF_DISALLOW_COPY_AND_ASSIGN(RetvalOp);
};
-REGISTER_XLA_OP(Name("_Retval").AllowResourceTypes().CompilationOnly(),
- RetvalOp);
+REGISTER_XLA_OP(
+ Name("_Retval").AllowResourceTypes().AllowVariantTypes().CompilationOnly(),
+ RetvalOp);
} // anonymous namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
index a7cc8c1..62b5cd3 100644
--- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc
@@ -89,8 +89,9 @@
}
XLAJIT_MAKE_UNARY(Sigmoid, Sigmoid(x));
-// Returns 0 if x is 0, -1 if x < 0 and 1 if x > 0.
-XLAJIT_MAKE_UNARY(Sign, xla::Sign(x));
+// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
+XLAJIT_MAKE_UNARY(Sign,
+ xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
// softplus(x) = log(1 + exp(x))
@@ -113,37 +114,11 @@
XLAJIT_MAKE_UNARY(Real, xla::Real(x));
XLAJIT_MAKE_UNARY(Imag, xla::Imag(x));
+XLAJIT_MAKE_UNARY(Erf, xla::Erf(x));
+XLAJIT_MAKE_UNARY(Erfc, xla::Erfc(x));
#undef XLAJIT_MAKE_UNARY
-// Erf/Erfc. For x in (-1, 1), the erf approximation is used; erfc polynomial
-// is used outside of this range.
-class ErfOp : public XlaOpKernel {
- public:
- explicit ErfOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
- void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaOp x = ctx->Input(0);
- xla::XlaOp one = xla::ScalarLike(x, 1.0);
- auto y =
- xla::Select(xla::Gt(xla::Abs(x), one), one - xla::Erfc(x), xla::Erf(x));
- ctx->SetOutput(0, y);
- }
-};
-REGISTER_XLA_OP(Name("Erf"), ErfOp);
-
-class ErfcOp : public XlaOpKernel {
- public:
- explicit ErfcOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
- void Compile(XlaOpKernelContext* ctx) override {
- xla::XlaOp x = ctx->Input(0);
- xla::XlaOp one = xla::ScalarLike(x, 1.0);
- auto y =
- xla::Select(xla::Lt(xla::Abs(x), one), one - xla::Erf(x), xla::Erfc(x));
- ctx->SetOutput(0, y);
- }
-};
-REGISTER_XLA_OP(Name("Erfc"), ErfcOp);
-
class LgammaOp : public XlaOpKernel {
public:
explicit LgammaOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc
index fd5ff10..f49da96 100644
--- a/tensorflow/compiler/tf2xla/kernels/while_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc
@@ -348,8 +348,11 @@
VLOG(1) << "Done building while loop";
}
-REGISTER_XLA_OP(Name("While").AllowResourceTypes(), XlaWhileOp);
-REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes(), XlaWhileOp);
-REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes(), XlaWhileOp);
+REGISTER_XLA_OP(Name("While").AllowResourceTypes().AllowVariantTypes(),
+ XlaWhileOp);
+REGISTER_XLA_OP(Name("StatelessWhile").AllowResourceTypes().AllowVariantTypes(),
+ XlaWhileOp);
+REGISTER_XLA_OP(Name("XlaWhile").AllowResourceTypes().AllowVariantTypes(),
+ XlaWhileOp);
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 1f9cfcd..0833264 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -43,6 +43,8 @@
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@@ -58,7 +60,11 @@
" elements while function has ", types.size());
}
for (int i = 0; i < types.size(); ++i) {
- if (types[i] != args[i].type && types[i] != DT_RESOURCE) {
+ // Don't perform type checks on resource variables and tensor
+ // lists (DT_VARIANT) as we have to trick the type system in order to
+ // plumb them through. DT_VARIANTS are wrapped in a DT_UINT8 tensor.
+ if (types[i] != args[i].type && types[i] != DT_RESOURCE &&
+ types[i] != DT_VARIANT) {
return errors::Internal(
"Argument ", i, " has declared type ", DataTypeString(args[i].type),
" but function parameter has type ", DataTypeString(types[i]));
@@ -1104,8 +1110,17 @@
result->outputs.resize(context->retvals().size());
std::vector<XlaExpression> retvals = context->retvals();
if (options.resolve_compile_time_constants) {
- TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
- client(), absl::Span<XlaExpression>(retvals)));
+ Status status = ResolveConstantExpressionsToConstants(
+ client(), absl::Span<XlaExpression>(retvals));
+
+ // If the HloEvaluator has not implemented an expression, just evaluate it
+ // at runtime.
+ if (status.code() == error::UNIMPLEMENTED) {
+ ConvertConstantsToExpressions(&builder,
+ absl::Span<XlaExpression>(retvals));
+ } else {
+ TF_RETURN_IF_ERROR(status);
+ }
} else {
ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc
index 14237df..2631403 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc
@@ -73,6 +73,11 @@
<< " have incompatible allow_resource_types settings.";
return false;
}
+ if (x.allow_variant_types != y.allow_variant_types) {
+ LOG(WARNING) << "Registrations of " << x.name
+ << " have incompatible allow_variant_types settings.";
+ return false;
+ }
if (!x.has_device_whitelist && !y.has_device_whitelist) {
LOG(WARNING) << "Duplicate registrations of " << x.name
<< "with no device whitelists.";
@@ -289,6 +294,9 @@
if (op_registration->allow_resource_types) {
allowed_values->add_type(DT_RESOURCE);
}
+ if (op_registration->allow_variant_types) {
+ allowed_values->add_type(DT_VARIANT);
+ }
// Don't build KernelDefs that have unsatisfiable type constraints.
if (allowed_values->type().empty()) {
unsatisfiable_type_constraint = true;
@@ -485,6 +493,11 @@
return *this;
}
+XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::AllowVariantTypes() {
+ registration_->allow_variant_types = true;
+ return *this;
+}
+
XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint(
absl::string_view attr_name, DataType allowed) {
std::set<DataType>& types =
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index ce3b6b2..c5e078a 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -212,6 +212,10 @@
// allow DT_RESOURCE.
bool allow_resource_types = false;
+ // Should we allow variant types for type attributes? Used by While to
+ // allow TensorList which is of type DT_VARIANT.
+ bool allow_variant_types = false;
+
// Mapping from attribute name to a list of supported types.
std::unordered_map<string, std::set<DataType>> type_constraints;
@@ -233,9 +237,9 @@
// Returns true if registrations x and y can both be added to the registry.
// This is always the case if they refer to different ops. If they refer to
- // the same op name, they must: have the same values for compilation_only and
- // allow_resource_types; use a device_whitelist; and their
- // whitelists must not intersect.
+ // the same op name, they must: have the same values for compilation_only,
+ // allow_resource_types and allow_variant_types; use a device_whitelist; and
+ // their whitelists must not intersect.
static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
static Status CompileTimeConstantInputs(const NodeDef& node_def,
@@ -293,6 +297,9 @@
// Allow DT_RESOURCE types for type parameters.
XlaOpRegistrationBuilder& AllowResourceTypes();
+ // Allow DT_VARIANT types for type parameters.
+ XlaOpRegistrationBuilder& AllowVariantTypes();
+
// Mark 'input_name' as an argument whose value must be known at compile-time.
XlaOpRegistrationBuilder& CompileTimeConstantInput(
absl::string_view input_name);
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index b30ab84..6be8154 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -494,3 +494,50 @@
"//tensorflow/core:test",
],
)
+
+cc_library(
+ name = "self_adjoint_eigen",
+ srcs = ["self_adjoint_eigen.cc"],
+ hdrs = ["self_adjoint_eigen.h"],
+ deps = [
+ ":arithmetic",
+ ":constants",
+ ":loops",
+ ":math",
+ ":matrix",
+ ":slicing",
+ "//tensorflow/compiler/xla:literal_util",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/core:lib",
+ ],
+)
+
+xla_test(
+ name = "self_adjoint_eigen_test",
+ size = "medium",
+ srcs = ["self_adjoint_eigen_test.cc"],
+ real_hardware_only = True,
+ shard_count = 10,
+ tags = ["optonly"],
+ deps = [
+ ":arithmetic",
+ ":constants",
+ ":matrix",
+ ":self_adjoint_eigen",
+ "//tensorflow/compiler/xla:array2d",
+ "//tensorflow/compiler/xla:array3d",
+ "//tensorflow/compiler/xla:literal",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/tests:client_library_test_base",
+ "//tensorflow/compiler/xla/tests:literal_test_util",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:test",
+ ],
+)
diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/client/lib/cholesky.cc
index 414bd14..4578fff 100644
--- a/tensorflow/compiler/xla/client/lib/cholesky.cc
+++ b/tensorflow/compiler/xla/client/lib/cholesky.cc
@@ -199,6 +199,7 @@
/*lower=*/true,
/*transpose_a=*/true,
/*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/block_size);
l = UpdateSliceInMinorDims(l, update, {i + k, i});
}
diff --git a/tensorflow/compiler/xla/client/lib/constants.cc b/tensorflow/compiler/xla/client/lib/constants.cc
index 1ada7b4..d0efd79 100644
--- a/tensorflow/compiler/xla/client/lib/constants.cc
+++ b/tensorflow/compiler/xla/client/lib/constants.cc
@@ -100,4 +100,28 @@
}
}
+XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
+ return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ switch (type) {
+ case F16:
+ return ConstantR0<Eigen::half>(
+ builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
+ case BF16:
+ return ConstantR0<bfloat16>(
+ builder, bfloat16(std::numeric_limits<float>::quiet_NaN()));
+ case F32:
+ return ConstantR0<float>(builder,
+ std::numeric_limits<float>::quiet_NaN());
+ case F64:
+ return ConstantR0<double>(builder,
+ std::numeric_limits<double>::quiet_NaN());
+ default:
+ return InvalidArgument(
+ "Operand to NanValue was %s, but must be a real-valued "
+ "floating-point type.",
+ PrimitiveType_Name(type));
+ }
+ });
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/constants.h b/tensorflow/compiler/xla/client/lib/constants.h
index 4e5310a..77e7ca6 100644
--- a/tensorflow/compiler/xla/client/lib/constants.h
+++ b/tensorflow/compiler/xla/client/lib/constants.h
@@ -142,6 +142,9 @@
// Returns the maximum representable finite value for 'type'.
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);
+// Returns a nan for the given type. Only valid for real-valued fp types.
+XlaOp NanValue(XlaBuilder* builder, PrimitiveType type);
+
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
diff --git a/tensorflow/compiler/xla/client/lib/constants_test.cc b/tensorflow/compiler/xla/client/lib/constants_test.cc
index f4320f6..180175b 100644
--- a/tensorflow/compiler/xla/client/lib/constants_test.cc
+++ b/tensorflow/compiler/xla/client/lib/constants_test.cc
@@ -155,5 +155,12 @@
{});
}
+XLA_TEST_F(ConstantsTest, NanValueF32) {
+ XlaBuilder builder(TestName());
+ NanValue(&builder, F32);
+ ComputeAndCompareR0<float>(&builder, std::numeric_limits<float>::quiet_NaN(),
+ {});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/math.cc b/tensorflow/compiler/xla/client/lib/math.cc
index 6758840..19d98d1 100644
--- a/tensorflow/compiler/xla/client/lib/math.cc
+++ b/tensorflow/compiler/xla/client/lib/math.cc
@@ -79,6 +79,34 @@
});
}
+XlaOp IsNegZero(XlaOp operand) {
+ auto& b = *operand.builder();
+ return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("IsNegZero", operand));
+ TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(operand));
+
+ // The bitwise representation of -0 in bfloat16 and IEEE 754 is 0x80...0
+ // (sign bit on, all other bits off).
+ switch (shape.element_type()) {
+ case F64:
+ return Eq(BitcastConvertType(operand, U64),
+ ConstantR0WithType(&b, U64, uint64{1} << 63));
+ case F32:
+ return Eq(BitcastConvertType(operand, U32),
+ ConstantR0WithType(&b, U32, uint32{1} << 31));
+ case F16:
+ case BF16:
+ // Not all XLA backends handle U16 well, so we convert to F32/U32.
+ // TODO(jlebar): It would be nice if we could stay in (B)F16/U16 for
+ // backends that *do* support it.
+ return Eq(BitcastConvertType(ConvertElementType(operand, F32), U32),
+ ConstantR0WithType(&b, U32, uint32{1} << 31));
+ default:
+ LOG(FATAL) << "Expected real fp type.";
+ }
+ });
+}
+
XlaOp Sqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, 0.5)); }
XlaOp Rsqrt(XlaOp operand) { return Pow(operand, ScalarLike(operand, -0.5)); }
@@ -87,44 +115,6 @@
XlaOp Reciprocal(XlaOp operand) { return ScalarLike(operand, 1.0) / operand; }
-namespace {
-
-// Polynomials for computing erf/erfc. Originally from cephes.
-// Note we use float for compatibility across devices, at the cost of some
-// precision for 64 bit computations.
-//
-// Coefficients are in descending order.
-std::array<float, 9> kErfcPCoefficient = {
- 2.46196981473530512524E-10, 5.64189564831068821977E-1,
- 7.46321056442269912687E0, 4.86371970985681366614E1,
- 1.96520832956077098242E2, 5.26445194995477358631E2,
- 9.34528527171957607540E2, 1.02755188689515710272E3,
- 5.57535335369399327526E2};
-std::array<float, 9> kErfcQCoefficient = {
- 1.00000000000000000000E0, 1.32281951154744992508E1,
- 8.67072140885989742329E1, 3.54937778887819891062E2,
- 9.75708501743205489753E2, 1.82390916687909736289E3,
- 2.24633760818710981792E3, 1.65666309194161350182E3,
- 5.57535340817727675546E2};
-std::array<float, 6> kErfcRCoefficient = {
- 5.64189583547755073984E-1, 1.27536670759978104416E0,
- 5.01905042251180477414E0, 6.16021097993053585195E0,
- 7.40974269950448939160E0, 2.97886665372100240670E0};
-std::array<float, 7> kErfcSCoefficient = {
- 1.00000000000000000000E0, 2.26052863220117276590E0,
- 9.39603524938001434673E0, 1.20489539808096656605E1,
- 1.70814450747565897222E1, 9.60896809063285878198E0,
- 3.36907645100081516050E0};
-std::array<float, 5> kErfTCoefficient = {
- 9.60497373987051638749E0, 9.00260197203842689217E1,
- 2.23200534594684319226E3, 7.00332514112805075473E3,
- 5.55923013010394962768E4};
-std::array<float, 6> kErfUCoefficient = {
- 1.00000000000000000000E0, 3.35617141647503099647E1,
- 5.21357949780152679795E2, 4.59432382970980127987E3,
- 2.26290000613890934246E4, 4.92673942608635921086E4};
-} // namespace
-
// Evaluate the polynomial given coefficients and `x`.
// N.B. Coefficients should be supplied in decreasing order.
XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const float> coefficients) {
@@ -135,73 +125,85 @@
return poly;
}
-// Compute an approximation of the error function complement (1 - erf(x)).
+// Computes an approximation of the error function complement (1 - erf(x)).
//
-// TODO(jlebar): This is not particularly efficient. The implementation in
-// Cephes that this follows was written for double precision, but our
-// coefficients are specified only to single-precision! Cephes has a different,
-// simpler implementation for single-precision.
+// Precondition: abs(x) >= 1. Otherwise, use ErfImpl.
//
-// Furthermore, we could simplify this further for f16 -- for example, because
-// exp(-4.2 * 4.2) = 0 (f16), the computations in service of the x < 8.0 branch
-// below are unnecessary.
+// This follows Cephes's f32 implementation of erfc, and so it may have errors
+// for double precision.
//
// See also these alternate implementations of erf and erfc:
//
// https://stackoverflow.com/questions/35148198
// https://stackoverflow.com/questions/35966695
//
+static XlaOp ErfcImpl(XlaOp x) {
+ // Coefficients for erfc(f32), from Cephes.
+ //
+ // erfc(x) = exp(-x^2) P(1/x), 1 < x < 2
+ static std::array<float, 9> kErfcPCoefficient{
+ +2.326819970068386E-2, -1.387039388740657E-1, +3.687424674597105E-1,
+ -5.824733027278666E-1, +6.210004621745983E-1, -4.944515323274145E-1,
+ +3.404879937665872E-1, -2.741127028184656E-1, +5.638259427386472E-1,
+ };
+ // erfc(x) = exp(-x^2) 1/x P(1/x^2), 2 < x < 14
+ static std::array<float, 8> kErfcRCoefficient{
+ -1.047766399936249E+1, +1.297719955372516E+1, -7.495518717768503E+0,
+ +2.921019019210786E+0, -1.015265279202700E+0, +4.218463358204948E-1,
+ -2.820767439740514E-1, +5.641895067754075E-1,
+ };
+
+ XlaOp abs_x = Abs(x);
+ XlaOp z = Exp(-x * x);
+ XlaOp q = ScalarLike(x, 1) / abs_x;
+ XlaOp y = q * q;
+ XlaOp p = Select(Lt(abs_x, ScalarLike(x, 2.0)),
+ EvaluatePolynomial(y, kErfcPCoefficient),
+ EvaluatePolynomial(y, kErfcRCoefficient));
+ y = z * q * p;
+ return Select(Lt(x, ScalarLike(x, 0)), ScalarLike(x, 2.0) - y, y);
+}
+
+// Compute a polynomial approximation of the error function.
+//
+// Precondition: abs(x) <= 1. Otherwise, use ErfcImpl.
+//
+// This follows Cephes's f32 implementation of erf, so it may have errors for
+// double precision.
+static XlaOp ErfImpl(XlaOp x) {
+ // Coefficients for by erf(f32), from Cephes.
+ //
+ // erf(x) = x P(x^2), 0 < x < 1
+ static std::array<float, 7> kErfTCoefficient{
+ +7.853861353153693E-5, -8.010193625184903E-4, +5.188327685732524E-3,
+ -2.685381193529856E-2, +1.128358514861418E-1, -3.761262582423300E-1,
+ +1.128379165726710E+0,
+ };
+
+ return x * EvaluatePolynomial(x * x, kErfTCoefficient);
+}
+
XlaOp Erfc(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erfc", x));
-
- XlaOp abs_x = Abs(x);
- XlaOp z = Exp(-x * x);
-
- XlaOp pp = EvaluatePolynomial(abs_x, kErfcPCoefficient);
- XlaOp pq = EvaluatePolynomial(abs_x, kErfcQCoefficient);
- XlaOp pr = EvaluatePolynomial(abs_x, kErfcRCoefficient);
- XlaOp ps = EvaluatePolynomial(abs_x, kErfcSCoefficient);
-
- XlaOp abs_x_small = Lt(abs_x, ScalarLike(x, 8.0));
- XlaOp y = Select(abs_x_small, z * pp / pq, z * pr / ps);
- XlaOp result_no_underflow =
- Select(Lt(x, ScalarLike(x, 0.0)), ScalarLike(x, 2.0) - y, y);
-
- // Check for edge cases, namely, exp(-x^2) is exactly 0, or the appropriate
- // denominator (ps or pq) is inf. (The check for exp(-x^2) == 0 is
- // necessary only for x == +/- inf, where this check lets us avoid
- // multiplying 0 by inf and getting nan.)
- auto is_pos_inf = [](XlaOp op) {
- return And(Not(IsFinite(op)), Gt(op, ScalarLike(op, 0)));
- };
- XlaOp underflow =
- Or(Eq(z, ScalarLike(z, 0)), Or(And(is_pos_inf(pq), abs_x_small),
- And(is_pos_inf(ps), Not(abs_x_small))));
- XlaOp result_underflow =
- Select(Lt(x, ScalarLike(x, 0)), FullLike(x, 2), FullLike(x, 0));
-
- return Select(underflow, result_underflow, result_no_underflow);
+ // erfc(x) =
+ // erfc_impl(x) if x > 1
+ // 1 - erf_impl(x) otherwise
+ return Select(Gt(Abs(x), ScalarLike(x, 1)), ErfcImpl(x),
+ ScalarLike(x, 1) - ErfImpl(x));
});
}
-// Compute a polynomial approximation of the error function.
XlaOp Erf(XlaOp x) {
auto& b = *x.builder();
return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- // Reject non-real non-fp inputs. (We could extend erf to accept complex
- // types, but it doesn't seem necessary at this point.)
- TF_ASSIGN_OR_RETURN(auto shape, b.GetShape(x));
- if (!ShapeUtil::ElementIsFloating(shape)) {
- return InvalidArgument(
- "erf only accepts real floating-point arrays or scalars, but got %s",
- shape.ToString());
- }
- XlaOp z = x * x;
- XlaOp pt = EvaluatePolynomial(z, kErfTCoefficient);
- XlaOp pu = EvaluatePolynomial(z, kErfUCoefficient);
- return x * pt / pu;
+ TF_RETURN_IF_ERROR(EnsureOperandIsRealFp("Erf", x));
+ // erf(x) =
+ // erf_impl(x) if x < 1
+ // 1 - erfc_impl(x) otherwise
+ return Select(Lt(Abs(x), ScalarLike(x, 1)), ErfImpl(x),
+ ScalarLike(x, 1) - ErfcImpl(x));
});
}
diff --git a/tensorflow/compiler/xla/client/lib/math.h b/tensorflow/compiler/xla/client/lib/math.h
index 907571c..b036fa2 100644
--- a/tensorflow/compiler/xla/client/lib/math.h
+++ b/tensorflow/compiler/xla/client/lib/math.h
@@ -28,6 +28,11 @@
XlaOp IsInf(XlaOp operand);
XlaOp IsNan(XlaOp operand);
+// Determines whether operand is equal to -0.
+//
+// Raises an error for integral or complex values.
+XlaOp IsNegZero(XlaOp operand);
+
// Returns the next number after 'from' in the direction of 'to' the same way
// std::nextafter(from, to) would.
XlaOp NextAfter(XlaOp from, XlaOp to);
diff --git a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
index 0fb13a7..f4ed0db 100644
--- a/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_exhaustive_test.cc
@@ -71,7 +71,7 @@
XlaOp (*op)(XlaOp);
float (*host_op)(float);
- ErrorSpec error{0.01};
+ ErrorSpec error{0.01, 0.01};
// If true, don't test +/-infinity or negative 0.
bool skip_pos_inf = false;
@@ -145,7 +145,6 @@
// Testcase{"asinh", Asinh, std::asinh},
// Testcase{"sinh", Sinh, std::sinh},
// Testcase{"cosh", Cosh, std::cosh}.set_fewer_infs_ok(),
-// Testcase{"erf", Erf, std::erf},
// Testcase{"round_to_even", RoundToEven,
// [](float x) { return std::nearbyint(x / 2) * 2; }},
//
@@ -161,7 +160,8 @@
//
// TODO(b/123355973): Test math functions not from math.cc (e.g. log).
// TODO(b/123355973): Test bf16 and f32.
-//
+// TODO(b/123355973): Get rid of skip_infs / skip_neg_zero below if possible.
+// TODO(b/123355973): Reduce lgamma error if possible; it is very high.
INSTANTIATE_TEST_CASE_P(
MathExhaustiveTest_Instantiation, MathExhaustiveTest,
::testing::ValuesIn(std::vector<Testcase>{
@@ -172,7 +172,8 @@
.set_skip_neg_zero(),
Testcase{"square", Square, [](float x) { return x * x; }},
Testcase{"reciprocal", Reciprocal, [](float x) { return 1 / x; }},
- Testcase{"erfc", Erfc, std::erfc},
+ Testcase{"erf", Erf, std::erf}.set_tolerance(0.001, 0.0001),
+ Testcase{"erfc", Erfc, std::erfc}.set_tolerance(0.001, 0.0001),
Testcase{"lgamma", Lgamma, std::lgamma}
.set_tolerance(0.1, 0.15)
.set_fewer_infs_ok(),
diff --git a/tensorflow/compiler/xla/client/lib/math_test.cc b/tensorflow/compiler/xla/client/lib/math_test.cc
index 364ac58..bdfb057 100644
--- a/tensorflow/compiler/xla/client/lib/math_test.cc
+++ b/tensorflow/compiler/xla/client/lib/math_test.cc
@@ -88,6 +88,22 @@
{false, false, false, false, false, false, false, true, true}));
ComputeAndCompareLiteral(&b, expected, {});
}
+
+ void TestIsNegZero() {
+ SetFastMathDisabled(true);
+ XlaBuilder b(TestName());
+ T inf(std::numeric_limits<float>::infinity());
+ T nan(std::numeric_limits<float>::quiet_NaN());
+ IsNegZero(AddParam(
+ LiteralUtil::CreateR1<T>({T{-0.0}, T{0}, T{1}, T{-1}, inf, -inf, nan}),
+ &b));
+
+ ComputeAndCompareLiteral(
+ &b,
+ LiteralUtil::CreateR1<bool>(
+ {true, false, false, false, false, false, false}),
+ {}, error_spec_);
+ }
};
// TODO(b/123355973): Add bfloat16 to TestTypes once it's working.
@@ -102,6 +118,7 @@
XLA_TYPED_TEST(MathTypedTest, LogEdgeCases) { this->TestLogEdgeCases(); }
XLA_TYPED_TEST(MathTypedTest, Log1pEdgeCases) { this->TestLog1pEdgeCases(); }
XLA_TYPED_TEST(MathTypedTest, IsInfOrNan) { this->TestIsInfOrNan(); }
+XLA_TYPED_TEST(MathTypedTest, IsNegZero) { this->TestIsNegZero(); }
// Check that certain ops only support real, floating-point inputs.
//
diff --git a/tensorflow/compiler/xla/client/lib/matrix_test.cc b/tensorflow/compiler/xla/client/lib/matrix_test.cc
index 79cf529..61f91a5 100644
--- a/tensorflow/compiler/xla/client/lib/matrix_test.cc
+++ b/tensorflow/compiler/xla/client/lib/matrix_test.cc
@@ -24,7 +24,6 @@
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace {
diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc
new file mode 100644
index 0000000..c2c8cae
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.cc
@@ -0,0 +1,398 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/loops.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/slicing.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace xla {
+
+namespace {
+
+// Jacobi rotation (also known as Givens rotation):
+// G = [[ c, s],
+// [-s, c]]
+// matmul(G_T, G) = I
+struct SymmetricSchurDecomposition {
+ XlaOp c; // cosine.
+ XlaOp s; // sine.
+ XlaOp reduction; // Reduction in the off diagonal after applying G.
+};
+
+// JacobiUpdate holds the intermediate orthogonal matrix, Jacobi-rotated matrix
+// and the off-diagonal norm of the rotated matrix. After each Jacobi iteration,
+// off-diagonal norm is reduced.
+struct JacobiUpdate {
+ XlaOp v;
+ XlaOp w;
+ XlaOp off_diagonal_norm;
+};
+
+// Given an n-by-n symmetric A and integers p and q that satisfy 0 <= p < q < n,
+// it computes a rotation matrix G = [[c, s], [-s, c]], such that
+// G_T * A[[p, q], [p, q]] * G
+// is diagonalized.
+//
+// def sym_schur2x2(A, p, q):
+// if np.abs(A[p, q]) > 1e-6:
+// tau = (A[q, q] - A[p, p]) / (2 * A[p, q])
+// if tau >= 0:
+// t = 1.0 / (tau + np.sqrt(1 + tau ** 2))
+// else:
+// t = -1.0 / (-tau + np.sqrt(1 + tau ** 2))
+// c = 1.0 / np.sqrt(1.0 + t ** 2)
+// s = t * c
+// else:
+// c = 1.0
+// s = 0.0
+// return c, s
+StatusOr<SymmetricSchurDecomposition> SymmetricShurDecomposition2x2(XlaOp a,
+ XlaOp p,
+ XlaOp q,
+ XlaOp tol) {
+ XlaBuilder* builder = a.builder();
+ TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+
+ PrimitiveType type = a_shape.element_type();
+
+ const int64 num_dims = a_shape.rank();
+
+ auto zero = ScalarLike(a, 0.0);
+ auto one = ScalarLike(a, 1.0);
+ auto two = ScalarLike(a, 2.0);
+
+ auto pqs = DynamicSliceInMinorDims(a, {p, q}, {1, 1});
+
+ auto ps = DynamicSliceInMinorDims(a, {p, p}, {1, 1});
+ auto qs = DynamicSliceInMinorDims(a, {q, q}, {1, 1});
+
+ auto tau = (qs - ps) / (pqs * two);
+ auto t_pos = one / (tau + Sqrt(one + Square(tau)));
+ auto t_neg = -one / (-tau + Sqrt(one + Square(tau)));
+ auto t = Select(Ge(tau, zero), t_pos, t_neg);
+
+ auto c_temp = Rsqrt(one + Square(t));
+ auto s_temp = t * c_temp;
+
+ auto c = Select(Ge(Abs(pqs), tol), c_temp, ZerosLike(c_temp) + one);
+ auto s = Select(Ge(Abs(pqs), tol), s_temp, ZerosLike(s_temp));
+ // Renormalize c and s to compensate for low precision arithmetic, this step
+ // is redundant if high precision float is used, like float64.
+ auto rnorm = Rsqrt(Square(c) + Square(s));
+
+ SymmetricSchurDecomposition schur;
+
+ schur.c = c * rnorm;
+ schur.s = s * rnorm;
+ schur.reduction =
+ Reduce(two * Square(pqs), zero, CreateScalarAddComputation(type, builder),
+ {num_dims - 2, num_dims - 1});
+ return schur;
+}
+
+StatusOr<JacobiUpdate> Update(JacobiUpdate jacobi_update, XlaOp p, XlaOp q,
+ XlaOp tol, int64 n) {
+ XlaBuilder* builder = jacobi_update.w.builder();
+ TF_ASSIGN_OR_RETURN(
+ SymmetricSchurDecomposition schur,
+ SymmetricShurDecomposition2x2(jacobi_update.w, p, q, tol));
+
+ TF_ASSIGN_OR_RETURN(Shape w_shape, builder->GetShape(jacobi_update.w));
+ const std::vector<int64> batch_dims(w_shape.dimensions().begin(),
+ w_shape.dimensions().end() - 2);
+ const int64 num_dims = w_shape.rank();
+
+ auto zero = ScalarLike(p, 0);
+
+ XlaOp c = schur.c;
+ XlaOp s = schur.s;
+
+ auto slice_p = DynamicSliceInMinorDims(jacobi_update.w, {p, zero}, {1, n});
+ auto slice_q = DynamicSliceInMinorDims(jacobi_update.w, {q, zero}, {1, n});
+
+ auto slice_p_new = c * slice_p - s * slice_q;
+ auto slice_q_new = s * slice_p + c * slice_q;
+
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {p, zero});
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {q, zero});
+
+ slice_p = DynamicSliceInMinorDims(jacobi_update.w, {zero, p}, {n, 1});
+ slice_q = DynamicSliceInMinorDims(jacobi_update.w, {zero, q}, {n, 1});
+
+ slice_p_new = c * slice_p - s * slice_q;
+ slice_q_new = s * slice_p + c * slice_q;
+
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_p_new, {zero, p});
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, slice_q_new, {zero, q});
+
+ // Zero out a_{pq} explicitly.
+ std::vector<int64> pq_dims(batch_dims.begin(), batch_dims.end());
+ pq_dims.push_back(1);
+ pq_dims.push_back(1);
+ auto pq_zero = ScalarLike(jacobi_update.w, 0.0);
+ auto pq_zeros = Broadcast(pq_zero, pq_dims);
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {p, q});
+ jacobi_update.w =
+ DynamicUpdateSliceInMinorDims(jacobi_update.w, pq_zeros, {q, p});
+
+ slice_p = DynamicSliceInMinorDims(jacobi_update.v, {zero, p}, {n, 1});
+ slice_q = DynamicSliceInMinorDims(jacobi_update.v, {zero, q}, {n, 1});
+
+ std::vector<int64> broadcast_dims(batch_dims.size());
+ std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
+ broadcast_dims.push_back(num_dims - 1);
+
+ // Renormalize the p-th and q-th columns. This step is redundant if high
+ // precision floats are used, like 64-bit float. But for 32-bit float, it
+ // becomes necessary. This step will not increase the overall complexity.
+ slice_p_new = c * slice_p - s * slice_q;
+ slice_p_new = Mul(
+ slice_p_new,
+ Rsqrt(Reduce(Square(slice_p_new), pq_zero,
+ CreateScalarAddComputation(w_shape.element_type(), builder),
+ {num_dims - 2})),
+ broadcast_dims);
+ slice_q_new = s * slice_p + c * slice_q;
+ slice_q_new = Mul(
+ slice_q_new,
+ Rsqrt(Reduce(Square(slice_q_new), pq_zero,
+ CreateScalarAddComputation(w_shape.element_type(), builder),
+ {num_dims - 2})),
+ broadcast_dims);
+
+ jacobi_update.v =
+ DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_p_new, {zero, p});
+ jacobi_update.v =
+ DynamicUpdateSliceInMinorDims(jacobi_update.v, slice_q_new, {zero, q});
+
+ jacobi_update.off_diagonal_norm = Sqrt(
+ Max(Square(jacobi_update.off_diagonal_norm) - schur.reduction, pq_zero));
+
+ return jacobi_update;
+}
+
+StatusOr<std::vector<XlaOp>> WhileLoopFn(
+ absl::Span<const XlaOp> initial_values, //
+ int matrix_dimension, //
+ int max_sweep_updates, //
+ PrimitiveType index_type, //
+ absl::string_view name, //
+ XlaBuilder* builder) {
+ auto while_cond_fn = [&](absl::Span<const XlaOp> values,
+ XlaBuilder* cond_builder) -> StatusOr<XlaOp> {
+ auto k = values[0];
+ auto off_diagonal_norm = values[5];
+ // tol = frobenius_norm * epsilon.
+ auto tol = values[6] * values[7];
+
+ auto max_sweeps = ScalarLike(k, max_sweep_updates);
+
+ auto sweep_update_cond = Gt(max_sweeps, k);
+
+ auto tol_cond = ReduceAll(Lt(tol, off_diagonal_norm),
+ xla::ConstantR0<bool>(cond_builder, false),
+ CreateScalarOrComputation(PRED, cond_builder));
+ return And(tol_cond, sweep_update_cond);
+ };
+
+ auto while_body_fn =
+ [&](absl::Span<const XlaOp> values,
+ XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
+ auto zero = Zero(body_builder, index_type);
+ auto one = One(body_builder, index_type);
+ auto end_index = ScalarLike(one, matrix_dimension);
+
+ // Indexes.
+ XlaOp k = values[0];
+ XlaOp p = values[1];
+ XlaOp q = values[2];
+
+ JacobiUpdate jacobi_update;
+ jacobi_update.v = values[3];
+ jacobi_update.w = values[4];
+ jacobi_update.off_diagonal_norm = values[5];
+
+ XlaOp frobenius_norm = values[6];
+ XlaOp tol = values[7];
+
+ TF_ASSIGN_OR_RETURN(jacobi_update,
+ Update(jacobi_update, p, q, tol, matrix_dimension));
+
+ std::vector<XlaOp> updated_values;
+ updated_values.reserve(values.size());
+
+ q = q + one;
+ p = Select(Eq(q, end_index), p + one, p);
+ k = Select(Eq(p, end_index - one), k + one, k);
+ p = Select(Eq(p, end_index - one), zero, p);
+ q = Select(Eq(q, end_index), p + one, q);
+
+ updated_values.push_back(k);
+ updated_values.push_back(p);
+ updated_values.push_back(q);
+
+ updated_values.push_back(jacobi_update.v);
+ updated_values.push_back(jacobi_update.w);
+ updated_values.push_back(jacobi_update.off_diagonal_norm);
+
+ updated_values.push_back(frobenius_norm);
+ updated_values.push_back(tol);
+
+ return updated_values;
+ };
+ std::vector<XlaOp> values;
+ TF_ASSIGN_OR_RETURN(values, WhileLoopHelper(while_cond_fn, while_body_fn,
+ initial_values, name, builder));
+
+ return values;
+}
+
+} // namespace
+
+// This is the cyclic Jacobi iteration. Please note that the eigenvalues are
+// possibly not ordered.
+//
+// def jacobi(A):
+// n, _ = A.shape
+// V = np.eye(n)
+// nfrob = np.sum(A ** 2)
+// ndiag = np.sum(np.diag(A) ** 2)
+// off = nfrob - ndiag
+// while off > 1e-6 * nfrob:
+// for p in range(n - 1):
+// for q in range(p + 1, n):
+// if off > 1e-6 * nfrob:
+// c, s = sym_schur2x2(A, p, q)
+// off = off - 2 * A[p, q] ** 2
+// A[[p, q], :] = np.matmul(np.array([[c, -s], [s, c]]),
+// A[[p, q], :])
+// A[:, [p, q]] = np.matmul(A[:, [p, q]],
+// np.array([[c, s], [-s, c]]))
+// V[:, [p, q]] = np.matmul(V[:, [p, q]],
+// np.array([[c, s], [-s, c]]))
+//
+// return A, V
+//
+// TODO(kuny): Implement parallel order Jacobi.
+//
+SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower, int64 max_iter,
+ float epsilon) {
+ XlaBuilder* builder = a.builder();
+ auto return_error = [&](const Status& status) {
+ SelfAdjointEigenResult result;
+ result.v = builder->ReportError(status);
+ result.w = builder->ReportError(status);
+ return result;
+ };
+ auto shape_with_status = builder->GetShape(a);
+ if (!shape_with_status.status().ok()) {
+ return return_error(shape_with_status.status());
+ }
+ Shape a_shape = shape_with_status.ValueOrDie();
+ const int64 num_dims = a_shape.rank();
+ if (num_dims < 2) {
+ return return_error(InvalidArgument(
+ "Arguments to Eigen decomposition must have rank >= 2: got shape %s.",
+ a_shape.ToString()));
+ }
+ PrimitiveType type = a_shape.element_type();
+ if (!primitive_util::IsFloatingPointType(type)) {
+ return return_error(InvalidArgument(
+ "Type of the input matrix must be float: got %s.", a_shape.ToString()));
+ }
+
+ const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+ const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+
+ if (m != n) {
+ return return_error(InvalidArgument(
+ "Arguments to Eigen decomposition must be square matrices: got shape "
+ "(%d, %d).",
+ m, n));
+ }
+
+ const int64 num_batch_dims = num_dims - 2;
+ std::vector<int64> batch_dims(num_batch_dims);
+ for (int i = 0; i < num_batch_dims; ++i) {
+ batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
+ }
+
+ auto zero = ScalarLike(a, 0.0);
+ auto tol = ScalarLike(a, epsilon);
+
+ auto v_init = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+ auto w_init = Triangle(a, lower);
+ w_init = w_init + TransposeInMinorDims(w_init) - w_init * v_init;
+
+ auto frobenius_norm = Sqrt(Reduce(Square(w_init), zero,
+ CreateScalarAddComputation(type, builder),
+ {num_dims - 2, num_dims - 1}));
+ auto diag = GetMatrixDiagonal(w_init);
+ auto diag_square =
+ Reduce(Square(diag), zero, CreateScalarAddComputation(type, builder),
+ {num_dims - 2});
+
+ auto off_diagonal_init =
+ Sqrt(Max(Square(frobenius_norm) - diag_square, zero));
+
+ auto output_with_status = WhileLoopFn(
+ {
+ Zero(builder, S32), // k
+ Zero(builder, S32), // p
+ One(builder, S32), // q
+ v_init, //
+ w_init, //
+ off_diagonal_init, //
+ frobenius_norm, //
+ tol, //
+ }, //
+ n, //
+ max_iter, //
+ S32, //
+ "CyclicJacobi", //
+ builder);
+ if (!output_with_status.status().ok()) {
+ return return_error(output_with_status.status());
+ }
+
+ auto output = output_with_status.ValueOrDie();
+
+ SelfAdjointEigenResult result;
+ result.v = output[3];
+ result.w = GetMatrixDiagonal(output[4]);
+
+ return result;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h
new file mode 100644
index 0000000..49fc17a
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h
@@ -0,0 +1,42 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
+
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// The eigenvalue decomposition of a symmetric matrix, the original matrix is
+// recovered by v * w * v_t.
+struct SelfAdjointEigenResult {
+ // The i-th column is the normalized eigenvector corresponding to the
+ // eigenvalue w[i]. Will return a matrix object if a is a matrix object.
+ XlaOp v;
+ // TODO(kuny): Sort the eigenvalues.
+ // The eigenvalues in ascending order, each repeated according to its
+ // multiplicity.
+ XlaOp w;
+};
+
+SelfAdjointEigenResult SelfAdjointEigen(XlaOp a, bool lower = true,
+ int64 max_iter = 100,
+ float epsilon = 1e-6);
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_SELF_ADJOINT_EIGEN_H_
diff --git a/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc
new file mode 100644
index 0000000..720c49b
--- /dev/null
+++ b/tensorflow/compiler/xla/client/lib/self_adjoint_eigen_test.cc
@@ -0,0 +1,300 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/self_adjoint_eigen.h"
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+
+class SelfAdjointEigenTest : public ClientLibraryTestBase {
+ protected:
+ void SetUp() override {
+ ClientLibraryTestBase::SetUp();
+ batch_3d_4x4_ = Array3D<float>{
+ {
+ {4, 6, 8, 10},
+ {6, 45, 54, 63},
+ {8, 54, 146, 166},
+ {10, 63, 166, 310},
+ },
+ {
+ {16, 24, 8, 12},
+ {24, 61, 82, 48},
+ {8, 82, 100, 6},
+ {12, 48, 6, 62},
+ },
+ };
+ matrix2d_8x8_ = Array2D<float>{
+ {14., 123., 49., 112., 115., 173., 182., 125.},
+ {123., 14., 60., 118., 150., 130., 91., 72.},
+ {49., 60., 138., 111., 106., 101., 115., 142.},
+ {112., 118., 111., 142., 91., 130., 25., 61.},
+ {115., 150., 106., 91., 116., 121., 128., 85.},
+ {173., 130., 101., 130., 121., 70., 151., 132.},
+ {182., 91., 115., 25., 128., 151., 66., 92.},
+ {125., 72., 142., 61., 85., 132., 92., 156.},
+ };
+ low_rank_4x4_ = Array2D<float>{
+ // x = [[1, 2, 3, 4], [1, -1, 1, -1]]
+ // matmul(x.T, x)
+ {2, 1, 4, 3},
+ {1, 5, 5, 9},
+ {4, 5, 10, 11},
+ {3, 9, 11, 17},
+ };
+ }
+ void TearDown() override { ClientLibraryTestBase::TearDown(); }
+
+ Array3D<float> get_unit_matrix_3d(const Array3D<float>& matrix) {
+ Array3D<float> result(matrix.n1(), matrix.n2(), matrix.n3(), 0.0);
+ for (int i = 0; i < matrix.n1(); ++i) {
+ for (int j = 0; j < matrix.n2(); ++j) {
+ result({i, j, j}) = 1.0;
+ }
+ }
+ return result;
+ }
+
+ Array3D<float> ExtractTriangularMatrix(const Array3D<float>& matrix,
+ bool lower) {
+ Array3D<float> result(matrix);
+ for (int i = 0; i < result.n1(); ++i) {
+ for (int j = 0; j < result.n2(); ++j) {
+ if (lower) {
+ for (int k = j + 1; k < result.n3(); ++k) {
+ result({i, j, k}) = 0.0;
+ }
+ } else {
+ for (int k = 0; k < j; ++k) {
+ result({i, j, k}) = 0.0;
+ }
+ }
+ }
+ }
+ return result;
+ }
+
+ XlaOp ComputeMatmulVWVt(SelfAdjointEigenResult result, XlaBuilder* builder) {
+ Shape shape = builder->GetShape(result.v).ValueOrDie();
+ std::vector<int64> out_dims = shape.dimensions();
+ std::vector<int64> broadcast_dims(shape.rank() - 1);
+ std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0);
+
+ broadcast_dims[shape.rank() - 2] = shape.rank() - 1;
+ auto vw = Mul(result.v, BroadcastInDim(result.w, out_dims, broadcast_dims));
+ return BatchDot(vw, TransposeInMinorDims(result.v),
+ PrecisionConfig::HIGHEST);
+ }
+
+ XlaOp GetAverageAbsoluteError(XlaOp m1, XlaOp m2, XlaBuilder* builder) {
+ Shape shape = builder->GetShape(m1).ValueOrDie();
+ int64 size = 1;
+ for (auto d : shape.dimensions()) {
+ size *= d;
+ }
+ return ReduceAll(Abs(m1 - m2), ConstantR0WithType(builder, F32, 0),
+ CreateScalarAddComputation(F32, builder)) /
+ ConstantR0WithType(builder, F32, size);
+ }
+
+ Array2D<float> GenerateRandomSymmetricMatrix(int size) {
+ Array2D<float> result{size, size, 0.0};
+ result.FillRandom(10 /* stddev */, 2 /* mean */);
+ for (int i = 0; i < size; ++i) {
+ for (int j = 0; j < i; ++j) {
+ result({j, i}) = result({i, j});
+ }
+ }
+ return result;
+ }
+
+ Array3D<float> batch_3d_4x4_;
+ Array2D<float> matrix2d_8x8_;
+ Array2D<float> low_rank_4x4_;
+ Array2D<int> wrong_type_4x4_;
+};
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_2x4x4) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ ComputeMatmulVWVt(result, &builder);
+
+ ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Lower_2x4x4) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR3Parameter<float>(
+ ExtractTriangularMatrix(batch_3d_4x4_, true), 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ ComputeMatmulVWVt(result, &builder);
+
+ ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_VWVt_EQ_A_Upper_2x4x4) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR3Parameter<float>(
+ ExtractTriangularMatrix(batch_3d_4x4_, false), 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a, false);
+ ComputeMatmulVWVt(result, &builder);
+
+ ComputeAndCompareR3<float>(&builder, batch_3d_4x4_, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_2x4x4) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR3Parameter<float>(batch_3d_4x4_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ BatchDot(result.v, TransposeInMinorDims(result.v), PrecisionConfig::HIGHEST);
+
+ ComputeAndCompareR3<float>(&builder, get_unit_matrix_3d(batch_3d_4x4_),
+ {a_data.get()}, ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_VtWV_EQ_A_Rank_Deficient_4x4) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(low_rank_4x4_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ ComputeMatmulVWVt(result, &builder);
+
+ ComputeAndCompareR2<float>(&builder, low_rank_4x4_, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_Eigen_8x8) {
+ XlaBuilder builder(TestName());
+
+ // This is computed by numpy.linalg.eigh with float32.
+ std::vector<float> expected{-182.69205, -116.86245, -105.74489, -9.545369,
+ 37.81711, 104.732285, 120.29153, 868.00385};
+
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ Sort(result.w);
+
+ ComputeAndCompareR1<float>(&builder, expected, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Test_Orthogonality_8x8) {
+ XlaBuilder builder(TestName());
+
+ float expected_vals = 1e-3;
+
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(matrix2d_8x8_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ // np.sum(norm(eye(n) - matmul(conj(T(v)), v)) / n**2
+ GetAverageAbsoluteError(IdentityMatrix(&builder, F32, 8, 8),
+ BatchDot(TransposeInMinorDims(result.v), result.v),
+ &builder);
+
+ ComputeAndCompareR0<float>(&builder, expected_vals, {a_data.get()},
+ ErrorSpec(1e-3, 1e-3));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Wrong_Type_Int) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a;
+ auto a_data = CreateR2Parameter<int>(wrong_type_4x4_, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ EXPECT_FALSE(result.v.valid());
+ EXPECT_FALSE(result.w.valid());
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_8x8) {
+ XlaBuilder builder(TestName());
+ int size = 8;
+ Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
+
+ ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_16x16) {
+ XlaBuilder builder(TestName());
+ int size = 16;
+ Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
+
+ ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_32x32) {
+ XlaBuilder builder(TestName());
+ int size = 32;
+ Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
+
+ ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
+XLA_TEST_F(SelfAdjointEigenTest, Various_Size_Random_Matrix_64x64) {
+ XlaBuilder builder(TestName());
+ int size = 64;
+ Array2D<float> a_val = GenerateRandomSymmetricMatrix(size);
+ XlaOp a;
+ auto a_data = CreateR2Parameter<float>(a_val, 0, "a", &builder, &a);
+ auto result = SelfAdjointEigen(a);
+ GetAverageAbsoluteError(ComputeMatmulVWVt(result, &builder), a, &builder);
+
+ ComputeAndCompareR0<float>(&builder, 1e-3, {a_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.cc b/tensorflow/compiler/xla/client/lib/triangular_solve.cc
index ba7fde1..1515e9e 100644
--- a/tensorflow/compiler/xla/client/lib/triangular_solve.cc
+++ b/tensorflow/compiler/xla/client/lib/triangular_solve.cc
@@ -346,8 +346,8 @@
}
XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
- bool transpose_a, bool conjugate_a, int64 block_size,
- PrecisionConfig::Precision precision) {
+ bool transpose_a, bool conjugate_a, bool unit_diagonal,
+ int64 block_size, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -406,6 +406,20 @@
return b;
}
+ // TODO(phawkins): consider pushing triangle masking into
+ // InvertDiagonalBlocks.
+ if (unit_diagonal) {
+ // Mask everything but the subdiagonal/superdiagonal elements.
+ a = lower ? Select(TriangleMask(a, -1), a, ZerosLike(a))
+ : Select(TriangleMask(a, 0), ZerosLike(a), a);
+ int64 k = ShapeUtil::GetDimension(a_shape, -1);
+ a = xla::Add(a, IdentityMatrix(builder, a_shape.element_type(), k, k),
+ /*broadcast_dimensions=*/{ndims - 2, ndims - 1});
+ } else {
+ // Mask off the ignored elements of the triangular matrix a.
+ a = Triangle(a, lower);
+ }
+
// We find the diagonal blocks of the coefficient matrix
auto diag_blocks = DiagonalBlocks(a, block_size);
@@ -413,11 +427,6 @@
auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
conjugate_a, precision);
- // Mask off the ignored elements of the triangular matrix a.
- // TODO(phawkins): it would probably be preferable to perform this masking
- // block by block inside SolveWithInvertedDiagonalBlocks.
- a = Triangle(a, lower);
-
// We now find the solution using GEMMs
auto x =
SolveWithInvertedDiagonalBlocks(a, b, inv_diag_blocks, left_side, lower,
diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve.h b/tensorflow/compiler/xla/client/lib/triangular_solve.h
index 50a3b30..b87ef72 100644
--- a/tensorflow/compiler/xla/client/lib/triangular_solve.h
+++ b/tensorflow/compiler/xla/client/lib/triangular_solve.h
@@ -54,12 +54,14 @@
// `conjugate_a` is a boolean indicating whether the entries of `a` are complex
// conjugated (independently of whether they are transposed), so that when both
// transpose_a and conjugate_a are true the effect is a Hermitian adjoint.
+// If `unit_diagonal` elements on the matrix diagonal are assumed to be '1' and
+// are not read by the triangular solve..
//
// Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no
// blocking is used.
XlaOp TriangularSolve(
XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
- bool conjugate_a, int64 block_size = 128,
+ bool conjugate_a, bool unit_diagonal, int64 block_size = 128,
PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc
index 284a2e9..b333ffa 100644
--- a/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc
+++ b/tensorflow/compiler/xla/client/lib/triangular_solve_test.cc
@@ -54,6 +54,20 @@
{kNan, kNan, kNan, 11}};
}
+Array2D<float> AValsLowerUnitDiagonal() {
+ return {{kNan, kNan, kNan, kNan},
+ {3, kNan, kNan, kNan},
+ {4, 7, kNan, kNan},
+ {5, 8, 10, kNan}};
+}
+
+Array2D<float> AValsUpperUnitDiagonal() {
+ return {{kNan, 3, 4, 5},
+ {kNan, kNan, 7, 8},
+ {kNan, kNan, kNan, 10},
+ {kNan, kNan, kNan, kNan}};
+}
+
Array2D<float> BValsRight() {
return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
}
@@ -97,6 +111,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
ComputeAndCompareR2<float>(&builder, Array2D<float>(0, 10),
@@ -112,6 +127,7 @@
TriangularSolve(a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -133,6 +149,7 @@
TriangularSolve(a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -154,6 +171,7 @@
TriangularSolve(a, b,
/*left_side=*/false, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -175,6 +193,7 @@
TriangularSolve(a, b,
/*left_side=*/false, /*lower=*/false,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -196,6 +215,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -218,6 +238,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -231,6 +252,26 @@
ErrorSpec(1e-2, 1e-2));
}
+XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNoTransposeUnitDiagonal) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a, b;
+ auto a_data =
+ CreateR2Parameter<float>(AValsLowerUnitDiagonal(), 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/true,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/true,
+ /*block_size=*/2);
+
+ Array2D<float> expected(
+ {{1., 2., 3.}, {1., -1., -3.}, {-4., 7., 18.}, {37., -61., -159.}});
+
+ ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotransposeIrregularblock) {
XlaBuilder builder(TestName());
@@ -240,6 +281,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/true,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/3);
Array2D<float> expected({
@@ -262,6 +304,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -284,6 +327,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<float> expected({
@@ -297,6 +341,28 @@
ErrorSpec(1e-2, 1e-2));
}
+XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotransposeUnitDiagonal) {
+ XlaBuilder builder(TestName());
+
+ XlaOp a, b;
+ auto a_data =
+ CreateR2Parameter<float>(AValsUpperUnitDiagonal(), 0, "a", &builder, &a);
+ auto b_data = CreateR2Parameter<float>(BValsLeft(), 1, "b", &builder, &b);
+ TriangularSolve(a, b,
+ /*left_side=*/true, /*lower=*/false,
+ /*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/true,
+ /*block_size=*/2);
+
+ Array2D<float> expected({{-1402., -1538., -1674.},
+ {575., 631., 687.},
+ {-93., -102., -111.},
+ {10., 11., 12.}});
+
+ ComputeAndCompareR2<float>(&builder, expected, {a_data.get(), b_data.get()},
+ ErrorSpec(1e-2, 1e-2));
+}
+
XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) {
XlaBuilder builder(TestName());
@@ -308,6 +374,7 @@
TriangularSolve(a, b,
/*left_side=*/false, /*lower=*/true,
/*transpose_a=*/true, /*conjugate_a=*/true,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<complex64> expected({
@@ -334,6 +401,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/true, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2);
Array2D<complex64> expected({
@@ -372,6 +440,7 @@
TriangularSolve(a, b,
/*left_side=*/true, /*lower=*/false,
/*transpose_a=*/false, /*conjugate_a=*/false,
+ /*unit_diagonal=*/false,
/*block_size=*/2));
ComputeAndCompareR3<float>(&builder, bvals, {a_data.get(), b_data.get()},
@@ -409,7 +478,7 @@
auto a_data = CreateR2Parameter<float>(avals, 0, "a", &builder, &a);
auto b_data = CreateR2Parameter<float>(bvals, 1, "b", &builder, &b);
auto x = TriangularSolve(a, b, spec.left_side, spec.lower, spec.transpose_a,
- /*conjugate_a=*/false,
+ /*conjugate_a=*/false, /*unit_diagonal=*/false,
/*block_size=*/3);
auto a_tri = Triangle(a, spec.lower);
a_tri = MaybeTransposeInMinorDims(a_tri, spec.transpose_a);
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index c00ba26..b64d352 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -573,16 +573,6 @@
});
}
-XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kAdd, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kMultiply, lhs, rhs, broadcast_dimensions);
-}
-
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
@@ -1004,36 +994,6 @@
});
}
-XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kEq, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kNe, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kGe, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kGt, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kLe, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kLt, lhs, rhs, broadcast_dimensions);
-}
-
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfig* precision_config) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1549,147 +1509,6 @@
});
}
-XlaOp XlaBuilder::Complex(const XlaOp& real, const XlaOp& imag,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kComplex, real, imag, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Conj(const XlaOp& operand) {
- return Complex(Real(operand), Neg(Imag(operand)));
-}
-
-XlaOp XlaBuilder::Sub(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kSubtract, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Div(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kDivide, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Rem(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kRemainder, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Max(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kMaximum, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Min(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kMinimum, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::And(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kAnd, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Or(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kOr, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Xor(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kXor, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Not(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kNot, operand);
-}
-
-XlaOp XlaBuilder::ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kShiftLeft, lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::ShiftRightArithmetic(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
- broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::ShiftRightLogical(
- const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
- broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Abs(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kAbs, operand);
-}
-
-XlaOp XlaBuilder::Atan2(const XlaOp& y, const XlaOp& x,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kAtan2, y, x, broadcast_dimensions);
-}
-
-XlaOp XlaBuilder::Exp(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kExp, operand);
-}
-
-XlaOp XlaBuilder::Expm1(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kExpm1, operand);
-}
-
-XlaOp XlaBuilder::Floor(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kFloor, operand);
-}
-
-XlaOp XlaBuilder::Ceil(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kCeil, operand);
-}
-
-XlaOp XlaBuilder::Round(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kRoundNearestAfz, operand);
-}
-
-XlaOp XlaBuilder::Log(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kLog, operand);
-}
-
-XlaOp XlaBuilder::Log1p(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kLog1p, operand);
-}
-
-XlaOp XlaBuilder::Sign(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kSign, operand);
-}
-
-XlaOp XlaBuilder::Clz(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kClz, operand);
-}
-
-XlaOp XlaBuilder::Cos(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kCos, operand);
-}
-
-XlaOp XlaBuilder::Sin(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kSin, operand);
-}
-
-XlaOp XlaBuilder::Tanh(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kTanh, operand);
-}
-
-XlaOp XlaBuilder::Real(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kReal, operand);
-}
-
-XlaOp XlaBuilder::Imag(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kImag, operand);
-}
-
-XlaOp XlaBuilder::IsFinite(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kIsFinite, operand);
-}
-
XlaOp XlaBuilder::Transpose(const XlaOp& operand,
absl::Span<const int64> permutation) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1745,11 +1564,6 @@
});
}
-XlaOp XlaBuilder::Pow(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return BinaryOp(HloOpcode::kPower, lhs, rhs, broadcast_dimensions);
-}
-
XlaOp XlaBuilder::ConvertElementType(const XlaOp& operand,
PrimitiveType new_element_type) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -1775,10 +1589,6 @@
});
}
-XlaOp XlaBuilder::Neg(const XlaOp& operand) {
- return UnaryOp(HloOpcode::kNegate, operand);
-}
-
XlaOp XlaBuilder::Clamp(const XlaOp& min, const XlaOp& operand,
const XlaOp& max) {
return TernaryOp(HloOpcode::kClamp, min, operand, max);
@@ -2159,8 +1969,8 @@
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
auto b = CreateSubBuilder("sum");
- b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
- b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
+ Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
+ b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
TF_ASSIGN_OR_RETURN(auto computation, b->Build());
return CrossReplicaSum(operand, computation, replica_groups,
/*channel_id=*/absl::nullopt);
@@ -2956,32 +2766,38 @@
XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Eq(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kEq, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Ne(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kNe, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Ge(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kGe, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Gt(lhs, rhs, broadcast_dimensions);
-}
-
-XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Lt(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kGt, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Le(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kLe, lhs, rhs,
+ broadcast_dimensions);
+}
+
+XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
+ return lhs.builder()->BinaryOp(HloOpcode::kLt, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
@@ -3084,78 +2900,96 @@
operand_shapes_with_layout);
}
-XlaOp Complex(const XlaOp& real, const XlaOp& imag,
+XlaOp Complex(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return real.builder()->Complex(real, imag, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kComplex, lhs, rhs,
+ broadcast_dimensions);
}
-XlaOp Conj(const XlaOp& operand) { return operand.builder()->Conj(operand); }
+XlaOp Conj(const XlaOp& operand) {
+ return Complex(Real(operand), Neg(Imag(operand)));
+}
XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Add(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kAdd, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Sub(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kSubtract, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Mul(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kMultiply, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Div(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kDivide, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Rem(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kRemainder, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Max(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kMaximum, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Min(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kMinimum, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->And(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kAnd, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Or(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kOr, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Xor(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kXor, lhs, rhs,
+ broadcast_dimensions);
}
-XlaOp Not(const XlaOp& operand) { return operand.builder()->Not(operand); }
+XlaOp Not(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kNot, operand);
+}
XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->ShiftLeft(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kShiftLeft, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->ShiftRightArithmetic(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kShiftRightArithmetic, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->ShiftRightLogical(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kShiftRightLogical, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
@@ -3250,48 +3084,67 @@
init_value, scatter);
}
-XlaOp Abs(const XlaOp& operand) { return operand.builder()->Abs(operand); }
-
-XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- absl::Span<const int64> broadcast_dimensions) {
- return y.builder()->Atan2(y, x, broadcast_dimensions);
+XlaOp Abs(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kAbs, operand);
}
-XlaOp Exp(const XlaOp& operand) { return operand.builder()->Exp(operand); }
+XlaOp Atan2(const XlaOp& lhs, const XlaOp& rhs,
+ absl::Span<const int64> broadcast_dimensions) {
+ return lhs.builder()->BinaryOp(HloOpcode::kAtan2, lhs, rhs,
+ broadcast_dimensions);
+}
-XlaOp Expm1(const XlaOp& operand) { return operand.builder()->Expm1(operand); }
-
-XlaOp Floor(const XlaOp& operand) { return operand.builder()->Floor(operand); }
-
-XlaOp Ceil(const XlaOp& operand) { return operand.builder()->Ceil(operand); }
-
-XlaOp Round(const XlaOp& operand) { return operand.builder()->Round(operand); }
-
-XlaOp Log(const XlaOp& operand) { return operand.builder()->Log(operand); }
-
-XlaOp Log1p(const XlaOp& operand) { return operand.builder()->Log1p(operand); }
-
-XlaOp Sign(const XlaOp& operand) { return operand.builder()->Sign(operand); }
-
-XlaOp Clz(const XlaOp& operand) { return operand.builder()->Clz(operand); }
-
-XlaOp Cos(const XlaOp& operand) { return operand.builder()->Cos(operand); }
-
-XlaOp Sin(const XlaOp& operand) { return operand.builder()->Sin(operand); }
-
-XlaOp Tanh(const XlaOp& operand) { return operand.builder()->Tanh(operand); }
-
-XlaOp Real(const XlaOp& operand) { return operand.builder()->Real(operand); }
-
-XlaOp Imag(const XlaOp& operand) { return operand.builder()->Imag(operand); }
+XlaOp Exp(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kExp, operand);
+}
+XlaOp Expm1(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kExpm1, operand);
+}
+XlaOp Floor(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kFloor, operand);
+}
+XlaOp Ceil(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kCeil, operand);
+}
+XlaOp Round(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kRoundNearestAfz, operand);
+}
+XlaOp Log(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kLog, operand);
+}
+XlaOp Log1p(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kLog1p, operand);
+}
+XlaOp Sign(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kSign, operand);
+}
+XlaOp Clz(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kClz, operand);
+}
+XlaOp Cos(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kCos, operand);
+}
+XlaOp Sin(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kSin, operand);
+}
+XlaOp Tanh(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kTanh, operand);
+}
+XlaOp Real(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kReal, operand);
+}
+XlaOp Imag(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kImag, operand);
+}
XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
absl::Span<const int64> broadcast_dimensions) {
- return lhs.builder()->Pow(lhs, rhs, broadcast_dimensions);
+ return lhs.builder()->BinaryOp(HloOpcode::kPower, lhs, rhs,
+ broadcast_dimensions);
}
XlaOp IsFinite(const XlaOp& operand) {
- return operand.builder()->IsFinite(operand);
+ return operand.builder()->UnaryOp(HloOpcode::kIsFinite, operand);
}
XlaOp ConvertElementType(const XlaOp& operand, PrimitiveType new_element_type) {
@@ -3302,7 +3155,9 @@
return operand.builder()->BitcastConvertType(operand, new_element_type);
}
-XlaOp Neg(const XlaOp& operand) { return operand.builder()->Neg(operand); }
+XlaOp Neg(const XlaOp& operand) {
+ return operand.builder()->UnaryOp(HloOpcode::kNegate, operand);
+}
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation) {
return operand.builder()->Transpose(operand, permutation);
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index c429035..ea4dc18 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -315,38 +315,6 @@
XlaOp ConstantLiteral(const LiteralSlice& literal);
- template <typename NativeT>
- XlaOp ConstantR0(NativeT value);
- template <typename NativeT>
- XlaOp ConstantR1(absl::Span<const NativeT> values);
- XlaOp ConstantR1(const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- XlaOp ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- XlaOp ConstantFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantFromArray(const Array<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR2FromArray2DWithLayout(const Array2D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR2FromArray2D(const Array2D<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR3FromArray3DWithLayout(const Array3D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR3FromArray3D(const Array3D<NativeT>& values);
- template <typename NativeT>
- XlaOp ConstantR4FromArray4DWithLayout(const Array4D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- XlaOp ConstantR4FromArray4D(const Array4D<NativeT>& values);
-
- template <typename NativeT>
- XlaOp ConstantR1(int64 length, NativeT value);
-
XlaOp Broadcast(const XlaOp& operand,
absl::Span<const int64> broadcast_sizes);
@@ -394,24 +362,6 @@
XlaOp GetTupleElement(const XlaOp& tuple_data, int64 index);
- XlaOp Eq(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Ne(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Ge(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Gt(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Lt(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Le(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
XlaOp Dot(const XlaOp& lhs, const XlaOp& rhs,
const PrecisionConfig* precision_config = nullptr);
@@ -476,49 +426,6 @@
const Shape& shape_with_layout, const string& opaque,
absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
- XlaOp Complex(const XlaOp& real, const XlaOp& imag,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Conj(const XlaOp& operand);
-
- XlaOp Add(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Sub(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Mul(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Div(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Rem(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Max(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Min(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp And(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Or(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Xor(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Not(const XlaOp& operand);
-
- XlaOp ShiftLeft(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
- XlaOp ShiftRightArithmetic(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
- XlaOp ShiftRightLogical(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
XlaOp Reduce(const XlaOp& operand, const XlaOp& init_value,
const XlaComputation& computation,
@@ -578,44 +485,6 @@
absl::Span<const std::pair<int64, int64>> padding, const XlaOp& source,
const XlaOp& init_value, const XlaComputation& scatter);
- XlaOp Abs(const XlaOp& operand);
-
- XlaOp Atan2(const XlaOp& y, const XlaOp& x,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp Exp(const XlaOp& operand);
-
- XlaOp Expm1(const XlaOp& operand);
-
- XlaOp Floor(const XlaOp& operand);
-
- XlaOp Ceil(const XlaOp& operand);
-
- XlaOp Round(const XlaOp& operand);
-
- XlaOp Log(const XlaOp& operand);
-
- XlaOp Log1p(const XlaOp& operand);
-
- XlaOp Sign(const XlaOp& operand);
-
- XlaOp Clz(const XlaOp& operand);
-
- XlaOp Cos(const XlaOp& operand);
-
- XlaOp Sin(const XlaOp& operand);
-
- XlaOp Tanh(const XlaOp& operand);
-
- XlaOp Real(const XlaOp& operand);
-
- XlaOp Imag(const XlaOp& operand);
-
- XlaOp Pow(const XlaOp& lhs, const XlaOp& rhs,
- absl::Span<const int64> broadcast_dimensions = {});
-
- XlaOp IsFinite(const XlaOp& operand);
-
XlaOp Iota(const Shape& shape, int64 iota_dimension);
XlaOp Iota(PrimitiveType type, int64 size);
@@ -626,8 +495,6 @@
XlaOp BitcastConvertType(const XlaOp& operand,
PrimitiveType new_element_type);
- XlaOp Neg(const XlaOp& operand);
-
XlaOp Transpose(const XlaOp& operand, absl::Span<const int64> permutation);
XlaOp Rev(const XlaOp& operand, absl::Span<const int64> dimensions);
@@ -825,48 +692,6 @@
const Shape& shape, const string& name);
friend XlaOp ConstantLiteral(XlaBuilder* builder,
const LiteralSlice& literal);
- template <typename NativeT>
- friend XlaOp ConstantR0(XlaBuilder* builder, NativeT value);
- template <typename NativeT>
- friend XlaOp ConstantR1(XlaBuilder* builder,
- absl::Span<const NativeT> values);
- friend XlaOp ConstantR1(XlaBuilder* builder,
- const tensorflow::core::Bitmap& values);
- template <typename NativeT>
- friend XlaOp ConstantR2(
- XlaBuilder* builder,
- std::initializer_list<std::initializer_list<NativeT>> values);
- template <typename NativeT>
- friend XlaOp ConstantFromArrayWithLayout(XlaBuilder* builder,
- const Array<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantFromArray(XlaBuilder* builder,
- const Array<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR2FromArray2DWithLayout(XlaBuilder* builder,
- const Array2D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR2FromArray2D(XlaBuilder* builder,
- const Array2D<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR3FromArray3DWithLayout(XlaBuilder* builder,
- const Array3D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR3FromArray3D(XlaBuilder* builder,
- const Array3D<NativeT>& values);
- template <typename NativeT>
- friend XlaOp ConstantR4FromArray4DWithLayout(XlaBuilder* builder,
- const Array4D<NativeT>& values,
- const Layout& layout);
- template <typename NativeT>
- friend XlaOp ConstantR4FromArray4D(XlaBuilder* builder,
- const Array4D<NativeT>& values);
-
- template <typename NativeT>
- friend XlaOp ConstantR1(XlaBuilder* builder, int64 length, NativeT value);
friend XlaOp Broadcast(const XlaOp& operand,
absl::Span<const int64> broadcast_sizes);
@@ -1970,81 +1795,6 @@
// Implementation details below this point.
//
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR0(NativeT value) {
- return ConstantLiteral(LiteralUtil::CreateR0<NativeT>(value));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(absl::Span<const NativeT> values) {
- return ConstantLiteral(LiteralUtil::CreateR1<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR1(int64 length, NativeT value) {
- Literal literal(ShapeUtil::MakeShape(
- primitive_util::NativeToPrimitiveType<NativeT>(), {length}));
- literal.PopulateWithValue(value);
- return ConstantLiteral(literal);
-}
-
-inline XlaOp XlaBuilder::ConstantR1(const tensorflow::core::Bitmap& values) {
- return ConstantLiteral(LiteralUtil::CreateR1(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2(
- std::initializer_list<std::initializer_list<NativeT>> values) {
- return ConstantLiteral(LiteralUtil::CreateR2<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantFromArrayWithLayout(const Array<NativeT>& values,
- const Layout& layout) {
- return ConstantLiteral(
- LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantFromArray(const Array<NativeT>& values) {
- return ConstantLiteral(LiteralUtil::CreateFromArray<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2FromArray2DWithLayout(
- const Array2D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- LiteralUtil::CreateFromArrayWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR2FromArray2D(const Array2D<NativeT>& values) {
- return ConstantLiteral(LiteralUtil::CreateR2FromArray2D<NativeT>(values));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR3FromArray3DWithLayout(
- const Array3D<NativeT>& values, const Layout& layout) {
- return ConstantLiteral(
- LiteralUtil::CreateR3FromArray3DWithLayout<NativeT>(values, layout));
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR3FromArray3D(const Array3D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR4FromArray4DWithLayout(
- const Array4D<NativeT>& values, const Layout& layout) {
- return ConstantFromArrayWithLayout(values, layout);
-}
-
-template <typename NativeT>
-XlaOp XlaBuilder::ConstantR4FromArray4D(const Array4D<NativeT>& values) {
- return ConstantFromArray(values);
-}
-
// Free function template implementations.
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/client/xla_builder_test.cc b/tensorflow/compiler/xla/client/xla_builder_test.cc
index 368a945..c9fa738 100644
--- a/tensorflow/compiler/xla/client/xla_builder_test.cc
+++ b/tensorflow/compiler/xla/client/xla_builder_test.cc
@@ -611,6 +611,29 @@
<< result_shape;
}
+TEST_F(XlaBuilderTest, DynamicSelectOnlyPredDynamic) {
+ XlaBuilder b(TestName());
+ Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShape(PRED, {10}), ShapeUtil::MakeShape(F32, {10}),
+ ShapeUtil::MakeShape(U32, {})});
+ auto p0 = Parameter(&b, 0, tuple_param_shape, "p0");
+ ASSERT_IS_OK(b.SetDynamicBinding(/*dynamic_size_param_num=*/0,
+ /*dynamic_size_param_index=*/{1},
+ /*target_param_num=*/0,
+ /*target_param_index=*/{0},
+ /*target_dim_num=*/0));
+ auto gte0 = GetTupleElement(p0, 0);
+ auto gte1 = GetTupleElement(p0, 1);
+
+ Select(gte0, gte1, gte1);
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
+ const Shape& result_shape =
+ module->entry_computation()->root_instruction()->shape();
+ EXPECT_TRUE(ContainersEqual(result_shape.dynamic_dimensions(), {true}))
+ << result_shape;
+}
+
TEST_F(XlaBuilderTest, DynamicPad) {
XlaBuilder b(TestName());
Shape tuple_param_shape = ShapeUtil::MakeTupleShape(
diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md
index 363fd17..db90d18 100644
--- a/tensorflow/compiler/xla/g3doc/operation_semantics.md
+++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md
@@ -1186,7 +1186,7 @@
<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
-$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
+$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ -0 & x = -0\\ NaN & x = NaN\\ +0 & x = +0\\ 1 & x > 0 \end{cases}$$
using the comparison operator of the element type of `operand`.
diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc
index 7e22a32..eebd824 100644
--- a/tensorflow/compiler/xla/index_util.cc
+++ b/tensorflow/compiler/xla/index_util.cc
@@ -21,7 +21,6 @@
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 8600e87..bb9bca0 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -1628,26 +1628,20 @@
return true;
}
- auto piece_is_all = [&]() {
- switch (shape().element_type()) {
- case F32:
- return AllElementsEqualValue<float>(piece.data<float>(), value);
- case F64:
- return AllElementsEqualValue<double>(piece.data<double>(), value);
- case F16:
- return AllElementsEqualValue<half>(piece.data<half>(),
- static_cast<half>(value));
- case BF16:
- return AllElementsEqualValue<bfloat16>(
- piece.data<bfloat16>(), static_cast<bfloat16>(value));
- default:
- return false;
- }
- };
- if (!piece_is_all()) {
- return false;
+ switch (shape().element_type()) {
+ case F32:
+ return AllElementsEqualValue<float>(piece.data<float>(), value);
+ case F64:
+ return AllElementsEqualValue<double>(piece.data<double>(), value);
+ case F16:
+ return AllElementsEqualValue<half>(piece.data<half>(),
+ static_cast<half>(value));
+ case BF16:
+ return AllElementsEqualValue<bfloat16>(
+ piece.data<bfloat16>(), static_cast<bfloat16>(value));
+ default:
+ return false;
}
- return true;
});
}
diff --git a/tensorflow/compiler/xla/literal.h b/tensorflow/compiler/xla/literal.h
index 368e460..c418be8 100644
--- a/tensorflow/compiler/xla/literal.h
+++ b/tensorflow/compiler/xla/literal.h
@@ -958,12 +958,15 @@
void MutableLiteralBase::AppendSparseElement(
absl::Span<const int64> multi_index, NativeT value,
const ShapeIndex& shape_index) {
- // TODO(jlebar): CHECK that multi_index is in range?
Piece& p = piece(shape_index);
const Shape& subshape = p.subshape();
CHECK(LayoutUtil::IsSparseArray(subshape));
int64 rank = subshape.rank();
CHECK_EQ(multi_index.size(), rank);
+ for (int64 i = 0; i < rank; ++i) {
+ CHECK_GE(multi_index[i], 0);
+ CHECK_LT(multi_index[i], subshape.dimensions(i));
+ }
int64 last_element = p.sparse_indices()->index_count();
CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout()));
p.sparse_indices()->Append(multi_index);
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc
index ce4bd6f..0891380 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.cc
+++ b/tensorflow/compiler/xla/python/local_computation_builder.cc
@@ -92,7 +92,12 @@
"previously created with a platform name of %s.",
platform_name, *g_platform_name);
}
- TF_RETURN_IF_ERROR(PlatformUtil::GetPlatform(platform_name).status());
+ TF_ASSIGN_OR_RETURN(se::Platform * platform,
+ PlatformUtil::GetPlatform(platform_name));
+ if (platform->VisibleDeviceCount() <= 0) {
+ return InvalidArgument("Platform %s has no visible devices.",
+ platform_name);
+ }
*g_platform_name = platform_name;
return Status::OK();
}
@@ -923,13 +928,11 @@
});
}
-LocalOp LocalComputationBuilder::TriangularSolve(const LocalOp& a,
- const LocalOp& b,
- bool left_side, bool lower,
- bool transpose_a,
- bool conjugate_a) {
+LocalOp LocalComputationBuilder::TriangularSolve(
+ const LocalOp& a, const LocalOp& b, bool left_side, bool lower,
+ bool transpose_a, bool conjugate_a, bool unit_diagonal) {
return xla::TriangularSolve(a.op(), b.op(), left_side, lower, transpose_a,
- conjugate_a);
+ conjugate_a, unit_diagonal);
}
LocalOp LocalComputationBuilder::Gather(
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h
index e3af88f..5e65ecf 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.h
+++ b/tensorflow/compiler/xla/python/local_computation_builder.h
@@ -425,7 +425,8 @@
LocalOp Cholesky(const LocalOp& a);
LocalOp TriangularSolve(const LocalOp& a, const LocalOp& b, bool left_side,
- bool lower, bool transpose_a, bool conjugate_a);
+ bool lower, bool transpose_a, bool conjugate_a,
+ bool unit_diagonal);
LocalOp Gather(const LocalOp& input, const LocalOp& start_indices,
const GatherDimensionNumbers& dimension_numbers,
diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i
index 7b2f69d..df2ab0b 100644
--- a/tensorflow/compiler/xla/python/local_computation_builder.i
+++ b/tensorflow/compiler/xla/python/local_computation_builder.i
@@ -453,16 +453,6 @@
// Literal
-%typemap(out) StatusOr<Literal> {
- if ($1.ok()) {
- Literal value = $1.ConsumeValueOrDie();
- $result = numpy::PyObjectFromXlaLiteral(*value);
- } else {
- PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
- SWIG_fail;
- }
-}
-
%typemap(in) const Literal& (StatusOr<Literal> literal_status) {
literal_status = numpy::XlaLiteralFromPyObject($input);
if (!literal_status.ok()) {
@@ -472,16 +462,26 @@
$1 = &literal_status.ValueOrDie();
}
-%typemap(out) Literal {
- $result = numpy::PyObjectFromXlaLiteral(*$1);
+%typemap(out) Literal (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
+ obj_status = numpy::PyObjectFromXlaLiteral(*$1);
+ if (!obj_status.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
+ SWIG_fail;
+ }
+ $result = obj_status.ValueOrDie().release();
}
-%typemap(out) StatusOr<Literal> {
+%typemap(out) StatusOr<Literal> (StatusOr<numpy::Safe_PyObjectPtr> obj_status) {
if (!$1.ok()) {
PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str());
SWIG_fail;
}
- $result = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
+ obj_status = numpy::PyObjectFromXlaLiteral($1.ValueOrDie());
+ if (!obj_status.ok()) {
+ PyErr_SetString(PyExc_RuntimeError, obj_status.status().ToString().c_str());
+ SWIG_fail;
+ }
+ $result = obj_status.ValueOrDie().release();
}
%typemap(in) const std::vector<Literal>& (std::vector<Literal> temps) {
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc
index 52c5c62..8e056f9 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.cc
+++ b/tensorflow/compiler/xla/python/numpy_bridge.cc
@@ -26,6 +26,10 @@
namespace numpy {
+Safe_PyObjectPtr make_safe(PyObject* object) {
+ return Safe_PyObjectPtr(object);
+}
+
int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) {
switch (primitive_type) {
case PRED:
@@ -349,13 +353,17 @@
return result;
}
-PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal) {
+StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal) {
if (literal.shape().IsTuple()) {
int num_elements = ShapeUtil::TupleElementCount(literal.shape());
- PyObject* tuple = PyTuple_New(num_elements);
+ std::vector<Safe_PyObjectPtr> elems(num_elements);
for (int i = 0; i < num_elements; i++) {
- PyTuple_SET_ITEM(tuple, i,
- PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
+ TF_ASSIGN_OR_RETURN(elems[i],
+ PyObjectFromXlaLiteral(LiteralSlice(literal, {i})));
+ }
+ Safe_PyObjectPtr tuple = make_safe(PyTuple_New(num_elements));
+ for (int i = 0; i < num_elements; i++) {
+ PyTuple_SET_ITEM(tuple.get(), i, elems[i].release());
}
return tuple;
} else {
@@ -365,10 +373,10 @@
dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i);
}
int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type());
- PyObject* array =
- PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0);
- CopyLiteralToNumpyArray(np_type, literal,
- reinterpret_cast<PyArrayObject*>(array));
+ Safe_PyObjectPtr array = make_safe(
+ PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0));
+ TF_RETURN_IF_ERROR(CopyLiteralToNumpyArray(
+ np_type, literal, reinterpret_cast<PyArrayObject*>(array.get())));
return array;
}
}
@@ -408,6 +416,12 @@
case NPY_BOOL:
CopyNumpyArrayToLiteral<bool>(py_array, literal);
break;
+ case NPY_INT8:
+ CopyNumpyArrayToLiteral<int8>(py_array, literal);
+ break;
+ case NPY_INT16:
+ CopyNumpyArrayToLiteral<int16>(py_array, literal);
+ break;
case NPY_INT32:
CopyNumpyArrayToLiteral<int32>(py_array, literal);
break;
@@ -417,6 +431,9 @@
case NPY_UINT8:
CopyNumpyArrayToLiteral<uint8>(py_array, literal);
break;
+ case NPY_UINT16:
+ CopyNumpyArrayToLiteral<uint16>(py_array, literal);
+ break;
case NPY_UINT32:
CopyNumpyArrayToLiteral<uint32>(py_array, literal);
break;
@@ -445,12 +462,18 @@
return Status::OK();
}
-void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
- PyArrayObject* py_array) {
+Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
+ PyArrayObject* py_array) {
switch (np_type) {
case NPY_BOOL:
CopyLiteralToNumpyArray<bool>(literal, py_array);
break;
+ case NPY_INT8:
+ CopyLiteralToNumpyArray<int8>(literal, py_array);
+ break;
+ case NPY_INT16:
+ CopyLiteralToNumpyArray<int16>(literal, py_array);
+ break;
case NPY_INT32:
CopyLiteralToNumpyArray<int32>(literal, py_array);
break;
@@ -460,6 +483,9 @@
case NPY_UINT8:
CopyLiteralToNumpyArray<uint8>(literal, py_array);
break;
+ case NPY_UINT16:
+ CopyLiteralToNumpyArray<uint16>(literal, py_array);
+ break;
case NPY_UINT32:
CopyLiteralToNumpyArray<uint32>(literal, py_array);
break;
@@ -482,8 +508,10 @@
CopyLiteralToNumpyArray<complex128>(literal, py_array);
break;
default:
- LOG(FATAL) << "No XLA literal container for Numpy type" << np_type;
+ return InvalidArgument(
+ "No XLA literal container for Numpy type number: %d", np_type);
}
+ return Status::OK();
}
PyObject* LongToPyIntOrPyLong(long x) { // NOLINT
diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h
index 40ff2d9..737fc4b 100644
--- a/tensorflow/compiler/xla/python/numpy_bridge.h
+++ b/tensorflow/compiler/xla/python/numpy_bridge.h
@@ -36,6 +36,16 @@
namespace numpy {
+struct PyDecrefDeleter {
+ void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+// Safe container for an owned PyObject. On destruction, the reference count of
+// the contained object will be decremented.
+using Safe_PyObjectPtr = std::unique_ptr<PyObject, PyDecrefDeleter>;
+
+Safe_PyObjectPtr make_safe(PyObject* object);
+
// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy
// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and
// vice versa.
@@ -74,7 +84,7 @@
// array data.
//
// The return value is a new reference.
-PyObject* PyObjectFromXlaLiteral(const LiteralSlice& literal);
+StatusOr<Safe_PyObjectPtr> PyObjectFromXlaLiteral(const LiteralSlice& literal);
// Converts a Numpy ndarray or a nested Python tuple thereof to a
// corresponding XLA literal.
@@ -90,8 +100,8 @@
Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array,
Literal* literal);
-void CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
- PyArrayObject* py_array);
+Status CopyLiteralToNumpyArray(int np_type, const LiteralSlice& literal,
+ PyArrayObject* py_array);
template <typename NativeT>
void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) {
diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py
index 37cae0e..6a7a27d 100644
--- a/tensorflow/compiler/xla/python/xla_client.py
+++ b/tensorflow/compiler/xla/python/xla_client.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""An in-process, local XLA client in Python, supporting AOT compilation."""
+"""An XLA client in Python, supporting AOT compilation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import abc
import collections
import enum # pylint: disable=g-bad-import-order
import inspect
@@ -49,13 +50,123 @@
OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS)
+@six.add_metaclass(abc.ABCMeta)
+class Backend(object):
+ """Abstract base class for XLA backends."""
+
+ @abc.abstractmethod
+ def buffer_from_pyval(self, pyval, device=0):
+ """Allocates a fresh buffer and populates it with `pyval`."""
+
+ @abc.abstractmethod
+ def delete_buffer(self, c_buffer):
+ """Deletes buffer `c_buffer`."""
+
+ @abc.abstractmethod
+ def destructure_tuple(self, c_buffer):
+ """Destructures a tuple buffer into a sequence of buffers."""
+
+ @abc.abstractmethod
+ def compile(self, computation, argument_shapes, compile_options):
+ """Compiles a computation. Returns an executable."""
+
+ @abc.abstractmethod
+ def delete_executable(self, executable):
+ """Deletes an executable."""
+
+ @abc.abstractmethod
+ def execute(self, executable, args):
+ """Runs an executable without replication."""
+
+ @abc.abstractmethod
+ def execute_replicated(self, executable, per_replica_args):
+ """Runs an executable in a replicated manner."""
+
+
+class XlaLocalBackend(Backend):
+ """XLA backend implemented using the in-process xla::LocalClient API."""
+
+ def buffer_from_pyval(self, pyval, device=0):
+ return c_api.LocalShapedBuffer.FromLiteral(pyval, None, device)
+
+ def delete_buffer(self, c_buffer):
+ c_api.DeleteLocalShapedBuffer(c_buffer)
+
+ def destructure_tuple(self, c_buffer):
+ result = c_api.DestructureLocalShapedBufferTuple(c_buffer)
+ return [result.Release(i) for i in xrange(result.size())]
+
+ def compile(self, c_computation, argument_shapes, compile_options):
+ return c_computation.Compile(argument_shapes, compile_options)
+
+ def delete_executable(self, executable):
+ assert isinstance(executable, c_api.CompiledLocalComputation)
+ c_api.DeleteCompiledLocalComputation(executable)
+
+ def execute(self, executable, args):
+ return executable.Execute(args)
+
+ def execute_replicated(self, executable, per_replica_args):
+ output_buffer_tup = executable.ExecutePerReplica(per_replica_args)
+ size = output_buffer_tup.size()
+ return [output_buffer_tup.Release(i) for i in xrange(size)]
+
+
+class XrtBackend(Backend):
+ """XLA backend implemented using XRT."""
+
+ def __init__(self, target):
+ self.target = target
+
+ def buffer_from_pyval(self, pyval, device=0):
+ if device != 0:
+ raise NotImplementedError(
+ 'Multi-replica execution is not yet supported via the XRT backend.')
+ return c_api.XrtAllocation.FromLiteral(pyval,
+ _maybe_encode_string(self.target))
+
+ def delete_buffer(self, c_buffer):
+ c_api.DeleteXrtAllocation(c_buffer)
+
+ def destructure_tuple(self, c_buffer):
+ result = c_api.DestructureXrtAllocationTuple(
+ c_buffer, _maybe_encode_string(self.target))
+ return [result.Release(i) for i in xrange(result.size())]
+
+ def compile(self, c_computation, argument_shapes, compile_options):
+ return c_computation.CompileForXrt(argument_shapes,
+ _maybe_encode_string(self.target))
+
+ def delete_executable(self, executable):
+ assert isinstance(executable, c_api.CompiledXrtComputation)
+ c_api.DeleteCompiledXrtComputation(executable)
+
+ def execute(self, executable, args):
+ return executable.Execute(args)
+
+ def execute_replicated(self, executable, per_replica_args):
+ if len(per_replica_args) != 1:
+ raise NotImplementedError(
+ 'Multi-replica execution is not yet supported via the XRT backend.')
+ return [executable.Execute(per_replica_args[0])]
+
+
+XLA_LOCAL_BACKEND = XlaLocalBackend()
+
+
class BackendType(enum.Enum):
XLA_LOCAL = 1
XRT = 2
-BackendSpec = collections.namedtuple('Backend', ('backend_type', 'target'))
-XLA_LOCAL_BACKEND = BackendSpec(BackendType.XLA_LOCAL, 'local')
+def BackendSpec(backend, target):
+ """Compatibility wrapper to support older clients. Do not use in new code."""
+ if backend == BackendType.XLA_LOCAL:
+ return XLA_LOCAL_BACKEND
+ elif backend == BackendType.XRT:
+ return XrtBackend(target)
+ else:
+ raise ValueError('Unknown backend {}'.format(backend))
def OpMetadataToProto(pyobj):
@@ -227,10 +338,6 @@
self.c_buffer = c_buffer
self._backend = backend
self._replica = replica
- if backend.backend_type == BackendType.XRT:
- self._delete = c_api.DeleteXrtAllocation
- else:
- self._delete = c_api.DeleteLocalShapedBuffer
@staticmethod
def from_pyval(pyval, replica=0, backend=XLA_LOCAL_BACKEND):
@@ -241,14 +348,7 @@
raise ValueError(
'Attempt to place buffer on replica {} when the replica count is {}'
.format(replica, num_replicas))
- if backend.backend_type == BackendType.XRT:
- if replica != 0:
- raise NotImplementedError(
- 'Multi-replica execution is not yet supported via the XRT backend.')
- cbuf = c_api.XrtAllocation.FromLiteral(
- pyval, _maybe_encode_string(backend.target))
- else:
- cbuf = c_api.LocalShapedBuffer.FromLiteral(pyval, None, replica)
+ cbuf = backend.buffer_from_pyval(pyval, replica)
return LocalBuffer(cbuf, backend, replica)
def to_py(self):
@@ -262,24 +362,17 @@
def delete(self):
if self.c_buffer is not None:
- self._delete(self.c_buffer)
+ self._backend.delete_buffer(self.c_buffer)
self.c_buffer = None
def destructure(self):
"""Assuming a tuple buffer, unpack it into constituent tuple elements."""
assert self.c_buffer is not None
- if self._backend.backend_type == BackendType.XRT:
- result = c_api.DestructureXrtAllocationTuple(
- self.c_buffer, _maybe_encode_string(self._backend.target))
- else:
- result = c_api.DestructureLocalShapedBufferTuple(self.c_buffer)
+ result = self._backend.destructure_tuple(self.c_buffer)
self.delete()
- size = result.size()
- destructured = tuple(
- LocalBuffer(
- result.Release(i), replica=self._replica, backend=self._backend)
- for i in xrange(size))
- return destructured
+ return tuple(
+ LocalBuffer(sub_buffer, replica=self._replica, backend=self._backend)
+ for sub_buffer in result)
def is_deleted(self):
return self.c_buffer is None
@@ -428,6 +521,20 @@
updated._check_minor_to_major() # pylint: disable=protected-access
return updated
+ def serialize(self, proto):
+ """Serializes 'shape' into proto."""
+ if self.is_tuple():
+ proto.element_type = xla_data_pb2.TUPLE
+ for shape in self.tuple_shapes():
+ shape.serialize(proto.tuple_shapes.add())
+ else:
+ proto.element_type = dtype_to_etype(self.element_type())
+ proto.dimensions.extend(self.dimensions())
+ proto.is_dynamic_dimension.extend([False for _ in self.dimensions()])
+ if self.minor_to_major():
+ proto.layout.format = xla_data_pb2.DENSE
+ proto.layout.minor_to_major.extend(self.minor_to_major())
+
def _wrap_shape(shape_info):
dtype, dims = shape_info
@@ -509,18 +616,6 @@
self._backend = backend
self._is_compiled = is_compiled
- # Ensure a reference to C-based destructor for use in __del__.
- if is_compiled:
- if backend.backend_type == BackendType.XRT:
- assert isinstance(c_computation, c_api.CompiledXrtComputation)
- self._delete = c_api.DeleteCompiledXrtComputation
- else:
- assert isinstance(c_computation, c_api.CompiledLocalComputation)
- self._delete = c_api.DeleteCompiledLocalComputation
- else:
- assert isinstance(c_computation, c_api.LocalComputation)
- self._delete = c_api.DeleteLocalComputation
-
@property
def computation(self):
if self._is_compiled:
@@ -574,11 +669,8 @@
compile_options = compile_options or CompileOptions()
compile_options.result_shape = result_shape
- if self._backend.backend_type == BackendType.XRT:
- c = self.computation.CompileForXrt(
- argument_shapes, _maybe_encode_string(self._backend.target))
- else:
- c = self.computation.Compile(argument_shapes, compile_options)
+ c = self._backend.compile(self.computation, argument_shapes,
+ compile_options)
return LocalComputation(c, is_compiled=True, backend=self._backend)
def CompileWithExampleArguments(self,
@@ -598,7 +690,7 @@
if check_for_deleted_args and any(arg.is_deleted() for arg in arguments):
raise ValueError('Executing with deleted local buffer argument')
raw_args = [arg.c_buffer for arg in arguments]
- output_buffer = self._c_computation.Execute(raw_args)
+ output_buffer = self._backend.execute(self._c_computation, raw_args)
return LocalBuffer(output_buffer, backend=self._backend, replica=0)
def ExecutePerReplica(self, arguments=None):
@@ -636,15 +728,8 @@
]
# Execute
- if self._backend.backend_type == BackendType.XRT:
- if len(stripped_args) > 1:
- raise NotImplementedError(
- 'Multi-replica execution is not yet supported via the XRT backend.')
- output_buffers = [self._c_computation.Execute(stripped_args[0])]
- else:
- output_buffer_tup = self._c_computation.ExecutePerReplica(stripped_args)
- size = output_buffer_tup.size()
- output_buffers = [output_buffer_tup.Release(i) for i in xrange(size)]
+ output_buffers = self._backend.execute_replicated(
+ self._c_computation, stripped_args)
# Wrap output handles in LocalBuffer instances
return tuple(
@@ -672,7 +757,12 @@
return [out.to_py() for out in self.ExecutePerReplica(arguments)]
def __del__(self):
- self._delete(self._c_computation)
+ # Ensure a reference to C-based destructor for use in __del__.
+ if self._is_compiled:
+ self._backend.delete_executable(self._c_computation)
+ else:
+ assert isinstance(self._c_computation, c_api.LocalComputation)
+ c_api.DeleteLocalComputation(self._c_computation)
def _make_replica_group_proto(replica_group):
@@ -1523,11 +1613,17 @@
"""Enqueues a QR decomposition onto the computation."""
return self._client.QR(a, full_matrices)
- def TriangularSolve(self, a, b, left_side=False, lower=False,
- transpose_a=False, conjugate_a=False):
+ def TriangularSolve(self,
+ a,
+ b,
+ left_side=False,
+ lower=False,
+ transpose_a=False,
+ conjugate_a=False,
+ unit_diagonal=False):
"""Enqueues a triangular-solve operation onto the computation."""
- return self._client.TriangularSolve(
- a, b, left_side, lower, transpose_a, conjugate_a)
+ return self._client.TriangularSolve(a, b, left_side, lower, transpose_a,
+ conjugate_a, unit_diagonal)
def Gather(self, a, start_indices, dimension_numbers, slice_sizes):
"""Enqueues a Gather operation onto the computation."""
@@ -1602,6 +1698,8 @@
Raises:
A runtime exception if the XLA service has already been initialized.
+ A runtime exception if the platform does not exist, or there are no devices
+ with that platform.
"""
platform_name = _maybe_encode_string(platform_name)
c_api.InitializePlatformName(platform_name)
diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py
index c80e792..aa38c06 100644
--- a/tensorflow/compiler/xla/python/xla_client_test.py
+++ b/tensorflow/compiler/xla/python/xla_client_test.py
@@ -88,6 +88,12 @@
class ComputationsWithConstantsTest(LocalComputationTest):
"""Tests focusing on Constant ops."""
+ def testConstantScalarSumS8(self):
+ c = self._NewComputation()
+ root = c.Add(c.Constant(np.int8(1)), c.Constant(np.int8(2)))
+ self.assertEqual(c.GetShape(root), c.GetReturnValueShape())
+ self._ExecuteAndCompareExact(c, expected=np.int8(3))
+
def testConstantScalarSumF32(self):
c = self._NewComputation()
root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14))
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4f6509c..a0cdf34 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -114,6 +114,7 @@
":bfloat16_normalization",
":bfloat16_support",
":hlo",
+ ":hlo_creation_utils",
":hlo_verifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -1203,7 +1204,6 @@
":tuple_points_to_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
@@ -1602,6 +1602,7 @@
":algebraic_simplifier",
":hlo",
":hlo_casting_utils",
+ ":hlo_creation_utils",
":hlo_matchers",
":hlo_parser",
":hlo_pass",
@@ -2306,6 +2307,7 @@
srcs = ["hlo_dataflow_analysis_test.cc"],
deps = [
":hlo",
+ ":hlo_creation_utils",
":hlo_dataflow_analysis",
":hlo_graph_dumper",
":hlo_matchers",
@@ -2476,6 +2478,7 @@
srcs = ["tuple_points_to_analysis_test.cc"],
deps = [
":hlo",
+ ":hlo_creation_utils",
":hlo_matchers",
":instruction_fusion",
":tuple_points_to_analysis",
@@ -2851,7 +2854,6 @@
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
@@ -3029,8 +3031,6 @@
":hlo_pass",
":shape_inference",
"//tensorflow/compiler/xla:literal_util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
],
)
@@ -3215,6 +3215,7 @@
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
],
)
@@ -3273,7 +3274,6 @@
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
@@ -3585,7 +3585,6 @@
":while_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
],
@@ -3641,7 +3640,6 @@
":hlo_evaluator",
":hlo_pass",
"//tensorflow/compiler/xla:util",
- "//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
@@ -3804,7 +3802,6 @@
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"@com_google_absl//absl/container:flat_hash_map",
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 549e675..acc2c28 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -892,7 +892,6 @@
} // namespace
Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) {
- Shape* shape;
HloInstruction *a, *b, *c, *d;
CHECK(Match(divide, m::Divide(m::Op(&a), m::Op(&b))));
// A/1 => A
@@ -955,6 +954,7 @@
break;
}
+ Shape* shape;
// exp(A)/exp(B) => exp(A-B)
if (Match(divide, m::Divide(m::Exp(m::Op(&a)), m::Exp(m::Op(&b)))
.WithShape(m::Shape(&shape)))) {
@@ -1005,8 +1005,9 @@
// (Backends can do this transformation, but generally only if the constant is
// a scalar.)
if (Match(divide, m::Divide(m::NonConstant(&a), m::Constant(&b)))) {
- Literal new_literal(b->shape());
- switch (b->shape().element_type()) {
+ Shape result_shape = b->literal().shape();
+ Literal new_literal(result_shape);
+ switch (result_shape.element_type()) {
case F16:
TF_RETURN_IF_ERROR(InvertConstant<half>(*b, &new_literal));
break;
@@ -2427,8 +2428,14 @@
// Reshape directly to empty constant if the shape contains zero-element
// dimension.
if (ShapeUtil::IsZeroElementArray(reshape->shape())) {
+ // If the instruction doesn't have a layout, use a default layout for
+ // the literal result.
+ Shape reshaped_shape = reshape->shape();
+ if (!LayoutUtil::HasLayout(reshaped_shape)) {
+ LayoutUtil::SetToDefaultLayout(&reshaped_shape);
+ }
auto empty_constant = HloInstruction::CreateConstant(
- Literal::CreateFromShape(reshape->shape()));
+ Literal::CreateFromShape(reshaped_shape));
return ReplaceWithNewInstruction(reshape, std::move(empty_constant));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index b5fed23..a908051 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -25,8 +25,10 @@
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/service/hlo_pass_fix.h"
@@ -45,6 +47,7 @@
using ::testing::ElementsAre;
namespace m = match;
+namespace op = xla::testing::opcode_matchers;
class AlgebraicSimplifierTest : public HloTestBase {
protected:
@@ -2747,12 +2750,13 @@
TEST_F(AlgebraicSimplifierTest, RemoveNoopSort) {
auto builder = HloComputation::Builder(TestName());
+ auto module = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {1});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
- auto module = CreateNewVerifiedModule();
+ TF_ASSERT_OK(
+ MakeSortHlo(keys_shape, {keys}, 0, &builder, module.get()).status());
HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
@@ -2761,6 +2765,7 @@
TEST_F(AlgebraicSimplifierTest, ReplaceEffectiveScalarKeyValueSortWithTuple) {
auto builder = HloComputation::Builder(TestName());
+ auto module = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {5, 0});
Shape values_shape = ShapeUtil::MakeShape(S32, {5, 0});
@@ -2770,10 +2775,10 @@
HloInstruction::CreateParameter(1, values_shape, "values0"));
auto values1 = builder.AddInstruction(
HloInstruction::CreateParameter(2, values_shape, "values1"));
- builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape, values_shape}), 0,
- keys, {values0, values1}));
- auto module = CreateNewVerifiedModule();
+ TF_ASSERT_OK(MakeSortHlo(ShapeUtil::MakeTupleShape(
+ {keys_shape, values_shape, values_shape}),
+ {keys, values0, values1}, 0, &builder, module.get())
+ .status());
HloComputation* computation = module->AddEntryComputation(builder.Build());
AlgebraicSimplifier simplifier(default_options_);
ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie());
@@ -4765,5 +4770,52 @@
m::Broadcast(m::ConstantScalar(1)))));
}
+TEST_F(AlgebraicSimplifierTest, ZeroSizedReshapeWithoutLayout) {
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param =
+ builder.AddInstruction(HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShape(F32, {1}), "param"));
+ HloInstruction* broadcast =
+ builder.AddInstruction(HloInstruction::CreateBroadcast(
+ ShapeUtil::MakeShape(F32, {0, 1}), param, {1}));
+
+ // Create a reshape with zero sized result and without layout.
+ Shape reshaped_shape = ShapeUtil::MakeShape(F32, {0});
+ reshaped_shape.clear_layout();
+ builder.AddInstruction(
+ HloInstruction::CreateReshape(reshaped_shape, broadcast));
+
+ std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifierOptions options;
+ AlgebraicSimplifier simplifier(options);
+ EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Constant());
+}
+
+TEST_F(AlgebraicSimplifierTest, DividedByConstantInstructionWithoutLayout) {
+ Shape shape = ShapeUtil::MakeShape(F32, {});
+ shape.clear_layout();
+ auto builder = HloComputation::Builder(TestName());
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+
+ HloInstruction* const_value = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(20.0f)));
+ builder.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kDivide,
+ param, const_value));
+
+ std::unique_ptr<VerifiedHloModule> module = CreateNewVerifiedModule();
+ module->AddEntryComputation(builder.Build());
+
+ AlgebraicSimplifierOptions options;
+ AlgebraicSimplifier simplifier(options);
+ EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie());
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Multiply());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/ar_crs_combiner.cc b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
index f8dff6a..99373dc 100644
--- a/tensorflow/compiler/xla/service/ar_crs_combiner.cc
+++ b/tensorflow/compiler/xla/service/ar_crs_combiner.cc
@@ -29,7 +29,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 551ac4b..2591ff6 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -16,6 +16,7 @@
#include "tensorflow/compiler/xla/service/bfloat16_normalization.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -282,8 +283,10 @@
HloInstruction* value = builder.AddInstruction(
HloInstruction::CreateParameter(1, s32_shape, "value"));
- HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}), 0, key, {value}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort,
+ MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, s32_shape}),
+ {key, value}, 0, &builder, module.get()));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, sort, 0));
@@ -308,8 +311,10 @@
HloInstruction* value = builder.AddInstruction(
HloInstruction::CreateParameter(1, bf16_shape, "value"));
- HloInstruction* sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({bf16_shape, bf16_shape}), 0, key, {value}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort,
+ MakeSortHlo(ShapeUtil::MakeTupleShape({bf16_shape, f32_shape}),
+ {key, value}, 0, &builder, module.get()));
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index e1b91b5..cbebbdc 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -191,6 +191,7 @@
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kSort:
case HloOpcode::kFusion:
// Map/reduce etc computations are always thread-local.
worklist.push_back(std::make_pair(subcomputation,
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 94af788..9830475 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -64,6 +64,7 @@
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kSort:
case HloOpcode::kFusion:
return CallContext::kParallel;
default:
diff --git a/tensorflow/compiler/xla/service/convolution_group_converter.cc b/tensorflow/compiler/xla/service/convolution_group_converter.cc
index 2ef723f..f11f9e5 100644
--- a/tensorflow/compiler/xla/service/convolution_group_converter.cc
+++ b/tensorflow/compiler/xla/service/convolution_group_converter.cc
@@ -223,7 +223,7 @@
// We are not yet supporting batch_group of sizes greater than 1.
TF_RET_CHECK(input_batch == batch_group_count);
- if (!is_cost_viable_(convolution)) {
+ if (!is_cost_viable_(convolution) || filter_expansion_) {
// We first obtain the expanded the filter (which is the convolution
// output). The batch dimension is the expanded one (which originally
// represents kernel input feature dimension). We mask the filter to zero
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
index 7fbe0fa..4ac61f4 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc
@@ -17,7 +17,6 @@
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
index 3361a59..fae9670 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc
@@ -29,7 +29,6 @@
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 4851018..0fecbaf 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -643,11 +643,13 @@
llvm::Function* function = b_->GetInsertBlock()->getParent();
llvm::Module* module = function->getParent();
- llvm::Function* matmul_func = llvm::cast<llvm::Function>(
- module->getOrInsertFunction(fn_name, matmul_type));
- matmul_func->setCallingConv(llvm::CallingConv::C);
- matmul_func->setDoesNotThrow();
- matmul_func->setOnlyAccessesArgMemory();
+ llvm::FunctionCallee matmul_func =
+ module->getOrInsertFunction(fn_name, matmul_type);
+ if (auto* fn = llvm::dyn_cast<llvm::Function>(matmul_func.getCallee())) {
+ fn->setCallingConv(llvm::CallingConv::C);
+ fn->setDoesNotThrow();
+ fn->setOnlyAccessesArgMemory();
+ }
// The Eigen runtime function expects column-major layout. If the matrices are
// row major, then use the following identity to compute the product:
diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
index c8312d8..0028fba 100644
--- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc
@@ -51,10 +51,11 @@
return Unimplemented("atan2");
}
// Create a function declaration.
- llvm::Function* function =
- llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(),
- rhs->getType()));
+ llvm::Function* function = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
+ lhs->getType(), lhs->getType(), rhs->getType())
+ .getCallee());
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
@@ -85,9 +86,11 @@
return Unimplemented("tanh");
}
// Create a function declaration.
- llvm::Function* function = llvm::cast<llvm::Function>(
- module_->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
- value->getType(), value->getType()));
+ llvm::Function* function = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(llvm_ir::AsStringRef(function_name),
+ value->getType(), value->getType())
+ .getCallee());
function->setCallingConv(llvm::CallingConv::C);
function->setDoesNotThrow();
function->setDoesNotAccessMemory();
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index b26bfff..efdda85 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -412,11 +412,18 @@
llvm::Function* acquire_func;
if (kind == XfeedKind::kInfeed) {
- acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type));
+ acquire_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(
+ runtime::kAcquireInfeedBufferForDequeueSymbolName, acquire_type)
+ .getCallee());
} else {
- acquire_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kAcquireOutfeedBufferForPopulationSymbolName, acquire_type));
+ acquire_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(
+ runtime::kAcquireOutfeedBufferForPopulationSymbolName,
+ acquire_type)
+ .getCallee());
}
acquire_func->setCallingConv(llvm::CallingConv::C);
@@ -429,11 +436,19 @@
llvm::Function* release_func;
if (kind == XfeedKind::kInfeed) {
- release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kReleaseInfeedBufferAfterDequeueSymbolName, release_type));
+ release_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(
+ runtime::kReleaseInfeedBufferAfterDequeueSymbolName,
+ release_type)
+ .getCallee());
} else {
- release_func = llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kReleaseOutfeedBufferAfterPopulationSymbolName, release_type));
+ release_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(
+ runtime::kReleaseOutfeedBufferAfterPopulationSymbolName,
+ release_type)
+ .getCallee());
}
release_func->setCallingConv(llvm::CallingConv::C);
@@ -629,9 +644,11 @@
b_.getInt8PtrTy()->getPointerTo(), b_.getInt32Ty(),
b_.getInt32Ty()->getPointerTo(), less_than_function->getType()},
/*isVarArg=*/false);
- auto* key_value_sort_func =
- llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- runtime::kKeyValueSortSymbolName, key_value_sort_type));
+ auto* key_value_sort_func = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(runtime::kKeyValueSortSymbolName,
+ key_value_sort_type)
+ .getCallee());
key_value_sort_func->setCallingConv(llvm::CallingConv::C);
key_value_sort_func->setDoesNotThrow();
llvm::Value* values = llvm_ir::EmitAllocaAtFunctionEntryWithCount(
@@ -1240,8 +1257,8 @@
LOG(WARNING) << "Using Eigen instead of MKL-DNN for single-threaded "
"conv2d function.";
}
- llvm::Function* conv_func = llvm::cast<llvm::Function>(
- module_->getOrInsertFunction(fn_name, conv_type));
+ llvm::Function* conv_func = llvm::dyn_cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, conv_type).getCallee());
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
@@ -1324,8 +1341,8 @@
? runtime::kEigenFftSymbolName
: runtime::kEigenSingleThreadedFftSymbolName;
- llvm::Function* fft_func = llvm::cast<llvm::Function>(
- module_->getOrInsertFunction(fn_name, fft_type));
+ llvm::Function* fft_func = llvm::dyn_cast<llvm::Function>(
+ module_->getOrInsertFunction(fn_name, fft_type).getCallee());
fft_func->setCallingConv(llvm::CallingConv::C);
fft_func->setDoesNotThrow();
fft_func->setOnlyAccessesInaccessibleMemOrArgMem();
@@ -2264,13 +2281,15 @@
InBoundsGEP(operands_alloca, {b_.getInt64(i)});
Store(operand_as_i8ptr, slot_in_operands_alloca);
}
- auto* custom_call_ir_function =
- llvm::cast<llvm::Function>(module_->getOrInsertFunction(
- AsStringRef(custom_call_target),
- llvm::FunctionType::get(
- /*Result=*/b_.getVoidTy(),
- /*Params=*/{i8_ptr_type, operands_alloca->getType()},
- /*isVarArg=*/false)));
+ auto* custom_call_ir_function = llvm::dyn_cast<llvm::Function>(
+ module_
+ ->getOrInsertFunction(
+ AsStringRef(custom_call_target),
+ llvm::FunctionType::get(
+ /*Result=*/b_.getVoidTy(),
+ /*Params=*/{i8_ptr_type, operands_alloca->getType()},
+ /*isVarArg=*/false))
+ .getCallee());
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call));
// Write the tuple table if the output is a tuple.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
index adfb839..84a5b05 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_function.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -266,9 +266,11 @@
/*Params=*/compute_function_params,
/*isVarArg=*/false);
- llvm::Function* fork_join_func =
- llvm::cast<llvm::Function>(module->getOrInsertFunction(
- runtime::kParallelForkJoinSymbolName, fork_join_type));
+ llvm::Function* fork_join_func = llvm::dyn_cast<llvm::Function>(
+ module
+ ->getOrInsertFunction(runtime::kParallelForkJoinSymbolName,
+ fork_join_type)
+ .getCallee());
fork_join_func->setCallingConv(llvm::CallingConv::C);
fork_join_func->setDoesNotThrow();
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
index a246494..93ef517 100644
--- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
@@ -36,57 +36,88 @@
const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX";
namespace {
-llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
- llvm::StringRef function_name,
- int vector_width,
- bool enable_fast_math) {
- llvm::Function* vector_tanh_function = module->getFunction(function_name);
- if (vector_tanh_function == nullptr) {
+
+// Replaces calls to the function `fn_name` with the code generated by
+// fn_body_generator.
+//
+// We assume that fn_name accepts either a scalar f32 or a vector of
+// vector_width f32s, and that fn_body_generator generates a function body with
+// the same inputs/outputs as fn_name.
+void RewriteCalls(
+ llvm::Module* module, const char* fn_name,
+ std::function<llvm::Value*(llvm::IRBuilder<>* b, llvm::Value* input,
+ int32 vector_width)>
+ fn_body_generator,
+ int32 vector_width, bool enable_fast_math) {
+ llvm::Function* fn = module->getFunction(fn_name);
+ if (fn == nullptr) {
// If the function declaration is not present in the module, there can't be
// any calls to resolve. Don't emit the function in this case.
- return nullptr;
+ return;
+ }
+
+ // Our task is to generate a function body for `fn`, but we can't generate a
+ // function body for an LLVM intrinsic. So if fn is an intrinsic, replace it
+ // with a new function.
+ if (fn->isIntrinsic()) {
+ llvm::Function* new_fn = llvm::Function::Create(
+ fn->getFunctionType(), llvm::GlobalValue::InternalLinkage,
+ llvm::Twine("xla_impl.") + fn_name, module);
+ fn->replaceAllUsesWith(new_fn);
+ fn->eraseFromParent();
+ fn = new_fn;
}
llvm::LLVMContext* context = &module->getContext();
- llvm::BasicBlock* vector_tanh_body =
- llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
-
- llvm::IRBuilder<> b(vector_tanh_body);
+ llvm::BasicBlock* fn_body = llvm::BasicBlock::Create(*context, "body", fn);
+ llvm::IRBuilder<> b(fn_body);
llvm::FastMathFlags fast_math_flags;
fast_math_flags.setFast(enable_fast_math);
b.setFastMathFlags(fast_math_flags);
- llvm::Value* input = &*vector_tanh_function->arg_begin();
- CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
- b.CreateRet(llvm_ir::EmitFastTanh(&b, input));
+ llvm::Value* input = &*fn->arg_begin();
- DCHECK(!llvm::verifyFunction(*vector_tanh_function));
- return vector_tanh_function;
+ // Upcast to vector type if input is a scalar.
+ if (vector_width == 1) {
+ llvm::Type* v1_type = llvm::VectorType::get(input->getType(), 1);
+ input = b.CreateInsertElement(llvm::UndefValue::get(v1_type), input,
+ uint64_t{0});
+ }
+
+ // Generate the vectorized code.
+ CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
+ llvm::Value* result = fn_body_generator(&b, input, vector_width);
+
+ // Downcast result to scalar type if necessary.
+ if (vector_width == 1) {
+ result = b.CreateExtractElement(result, uint64_t{0});
+ }
+ b.CreateRet(result);
+ DCHECK(!llvm::verifyFunction(*fn));
+
+ // Force-inline `fn` into all of its callers and then delete `fn`.
+ //
+ // TODO(b/73081976): Should we avoid inlining these in some cases?
+ std::vector<llvm::CallInst*> calls_to_inline;
+ for (auto* user : fn->users()) {
+ calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
+ }
+ for (auto* call_to_inline : calls_to_inline) {
+ llvm::InlineFunctionInfo inline_function_info;
+ CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
+ }
+ fn->eraseFromParent();
}
-llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
- llvm::StringRef function_name,
- int vector_width,
- bool enable_fast_math) {
- llvm::Function* vector_exp_function = module->getFunction(function_name);
- if (vector_exp_function == nullptr) {
- // If the function declaration is not present in the module, there can't be
- // any calls to resolve. Don't emit the function in this case.
- return nullptr;
- }
+llvm::Value* GenerateVF32Tanh(llvm::IRBuilder<>* b, llvm::Value* input,
+ int32 /*vector_width*/) {
+ return llvm_ir::EmitFastTanh(b, input);
+}
- llvm::LLVMContext* context = &module->getContext();
-
- llvm::BasicBlock* vector_exp_body =
- llvm::BasicBlock::Create(*context, "body", vector_exp_function);
-
- llvm::IRBuilder<> b(vector_exp_body);
- llvm::FastMathFlags fast_math_flags;
- fast_math_flags.setFast(enable_fast_math);
- b.setFastMathFlags(fast_math_flags);
-
- VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32");
+llvm::Value* GenerateVF32Exp(llvm::IRBuilder<>* b, llvm::Value* input,
+ int32 vector_width) {
+ VectorSupportLibrary vsl(F32, vector_width, b, "exp_f32");
// This implements the same polynomial approximation as implemented in Eigen3.
@@ -107,7 +138,6 @@
const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1);
const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1);
- llvm::Value* input = &*vector_exp_function->arg_begin();
llvm::Value* input_clamped =
vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi);
llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half));
@@ -128,49 +158,24 @@
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
- b.CreateVectorSplat(vector_width, b.getInt32(0x7f));
+ b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
llvm::Value* vector_constant_23 =
- b.CreateVectorSplat(vector_width, b.getInt32(23));
+ b->CreateVectorSplat(vector_width, b->getInt32(23));
llvm::Type* i32_vector_type =
- llvm::VectorType::get(b.getInt32Ty(), vector_width);
+ llvm::VectorType::get(b->getInt32Ty(), vector_width);
// fx is clamped so we don't have to worry about it being out of range for
// i32.
- llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type);
- emm0 = b.CreateAdd(emm0, vector_constant_0x7f);
- emm0 = b.CreateShl(emm0, vector_constant_23);
- llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type());
+ llvm::Value* emm0 = b->CreateFPToSI(fx, i32_vector_type);
+ emm0 = b->CreateAdd(emm0, vector_constant_0x7f);
+ emm0 = b->CreateShl(emm0, vector_constant_23);
+ llvm::Value* emm0_f32 = b->CreateBitCast(emm0, vsl.vector_type());
- llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input);
-
- b.CreateRet(result);
-
- DCHECK(!llvm::verifyFunction(*vector_exp_function));
- return vector_exp_function;
+ return vsl.Max(vsl.Mul(y, emm0_f32), input);
}
-llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
- llvm::StringRef function_name,
- int vector_width,
- bool enable_fast_math) {
- llvm::Function* vector_log_function = module->getFunction(function_name);
- if (vector_log_function == nullptr) {
- // If the function declaration is not present in the module, there can't be
- // any calls to resolve. Don't emit the function in this case.
- return nullptr;
- }
-
- llvm::LLVMContext* context = &module->getContext();
-
- llvm::BasicBlock* vector_log_body =
- llvm::BasicBlock::Create(*context, "body", vector_log_function);
-
- llvm::IRBuilder<> b(vector_log_body);
- llvm::FastMathFlags fast_math_flags;
- fast_math_flags.setFast(enable_fast_math);
- b.setFastMathFlags(fast_math_flags);
-
- llvm::Value* input = &*vector_log_function->arg_begin();
- VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32");
+llvm::Value* GenerateVF32Log(llvm::IRBuilder<>* b, llvm::Value* input,
+ int32 vector_width) {
+ VectorSupportLibrary vsl(F32, vector_width, b, "log_f32");
const llvm::APFloat half = GetIeeeF32(0.5);
const llvm::APFloat one = GetIeeeF32(1.0);
@@ -193,129 +198,107 @@
// The smallest non denormalized float number.
const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000);
const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000);
+ const llvm::APFloat pos_inf = GetIeeeF32FromBitwiseRep(0x7f800000);
const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000);
// invalid_mask is set if x is negative or NaN (and therefore output
// must be NaN).
llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector());
- llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
+ llvm::Value* is_zero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector());
+ llvm::Value* is_pos_inf_mask = vsl.FCmpEQMask(input, pos_inf);
// Cut off denormalized stuff.
- input = vsl.Max(min_norm_pos, input);
+ llvm::Value* tmp0 = vsl.Max(min_norm_pos, input);
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
- b.CreateVectorSplat(vector_width, b.getInt32(0x7f));
+ b->CreateVectorSplat(vector_width, b->getInt32(0x7f));
llvm::Value* vector_constant_23 =
- b.CreateVectorSplat(vector_width, b.getInt32(23));
+ b->CreateVectorSplat(vector_width, b->getInt32(23));
llvm::Type* i32_vector_type =
- llvm::VectorType::get(b.getInt32Ty(), vector_width);
+ llvm::VectorType::get(b->getInt32Ty(), vector_width);
- llvm::Value* emm0 =
- b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23);
+ llvm::Value* emm0 = b->CreateLShr(b->CreateBitCast(tmp0, i32_vector_type),
+ vector_constant_23);
// Keep only the fractional part.
- input = vsl.FloatAnd(input, inv_mant_mask);
- input = vsl.FloatOr(input, half);
+ tmp0 = vsl.FloatAnd(tmp0, inv_mant_mask);
+ tmp0 = vsl.FloatOr(tmp0, half);
- emm0 = b.CreateSub(emm0, vector_constant_0x7f);
- llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type()));
+ emm0 = b->CreateSub(emm0, vector_constant_0x7f);
+ llvm::Value* e = vsl.Add(one, b->CreateSIToFP(emm0, vsl.vector_type()));
// part2:
// if( x < SQRTHF ) {
// e -= 1;
// x = x + x - 1.0;
// } else { x = x - 1.0; }
- llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF);
- llvm::Value* tmp = vsl.FloatAnd(input, mask);
- input = vsl.Sub(input, one);
+ llvm::Value* mask = vsl.FCmpOLTMask(tmp0, cephes_SQRTHF);
+ llvm::Value* tmp1 = vsl.FloatAnd(tmp0, mask);
+ tmp0 = vsl.Sub(tmp0, one);
e = vsl.Sub(e, vsl.FloatAnd(mask, one));
- input = vsl.Add(input, tmp);
+ tmp0 = vsl.Add(tmp0, tmp1);
- llvm::Value* x2 = vsl.Mul(input, input);
- llvm::Value* x3 = vsl.Mul(x2, input);
+ llvm::Value* x2 = vsl.Mul(tmp0, tmp0);
+ llvm::Value* x3 = vsl.Mul(x2, tmp0);
llvm::Value *y, *y1, *y2;
- y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1);
- y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4);
- y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7);
- y = vsl.MulAdd(y, input, cephes_log_p2);
- y1 = vsl.MulAdd(y1, input, cephes_log_p5);
- y2 = vsl.MulAdd(y2, input, cephes_log_p8);
+ y = vsl.MulAdd(tmp0, cephes_log_p0, cephes_log_p1);
+ y1 = vsl.MulAdd(tmp0, cephes_log_p3, cephes_log_p4);
+ y2 = vsl.MulAdd(tmp0, cephes_log_p6, cephes_log_p7);
+ y = vsl.MulAdd(y, tmp0, cephes_log_p2);
+ y1 = vsl.MulAdd(y1, tmp0, cephes_log_p5);
+ y2 = vsl.MulAdd(y2, tmp0, cephes_log_p8);
y = vsl.MulAdd(y, x3, y1);
y = vsl.MulAdd(y, x3, y2);
y = vsl.Mul(y, x3);
y1 = vsl.Mul(cephes_log_q1, e);
- tmp = vsl.Mul(half, x2);
+ llvm::Value* tmp2 = vsl.Mul(half, x2);
y = vsl.Add(y, y1);
- input = vsl.Sub(input, tmp);
+ tmp0 = vsl.Sub(tmp0, tmp2);
y2 = vsl.Mul(cephes_log_q2, e);
- input = vsl.Add(input, y);
- input = vsl.Add(input, y2);
+ tmp0 = vsl.Add(tmp0, y);
+ tmp0 = vsl.Add(tmp0, y2);
- // Negative arg will be NAN, 0 will be -INF.
- llvm::Value* or_lhs =
- vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask));
- llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
- llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
+ // Contains +/-inf where +/-inf is the correct answer, otherwise 0.
+ llvm::Value* result_inf = vsl.FloatOr(vsl.FloatAnd(is_zero_mask, minus_inf),
+ vsl.FloatAnd(is_pos_inf_mask, pos_inf));
- b.CreateRet(result);
+ // Contains a finite result or nan. This is the correct answer only if both
+ // result_minus_inf and result_pos_inf are both 0.
+ //
+ // (This implementation works because 0xffffffff is a nan.)
+ llvm::Value* result_finite_or_nan = vsl.FloatOr(tmp0, invalid_mask);
- DCHECK(!llvm::verifyFunction(*vector_log_function));
- return vector_log_function;
+ // Combine the above into a final result.
+ return vsl.FloatOr(result_inf,
+ vsl.FloatAndNot(vsl.FloatOr(is_zero_mask, is_pos_inf_mask),
+ result_finite_or_nan));
}
} // namespace
void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) {
- auto* tanh_v4f32 =
- EmitVectorF32TanhIfNeeded(module, kTanhV4F32SymbolName,
- /*vector_width=*/4, enable_fast_math);
- auto* tanh_v8f32 =
- EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName,
- /*vector_width=*/8, enable_fast_math);
+ // Curry some params to RewriteCalls.
+ auto rewrite_calls =
+ std::bind(RewriteCalls, module, std::placeholders::_1,
+ std::placeholders::_2, std::placeholders::_3, enable_fast_math);
- auto* exp_v4f32 =
- EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName,
- /*vector_width=*/4, enable_fast_math);
- auto* exp_v8f32 =
- EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName,
- /*vector_width=*/8, enable_fast_math);
+ rewrite_calls("tanhf", GenerateVF32Tanh, /*vector_width=*/1);
+ rewrite_calls("llvm.tanh.f32", GenerateVF32Tanh, /*vector_width=*/1);
+ rewrite_calls(kTanhV4F32SymbolName, GenerateVF32Tanh, /*vector_width=*/4);
+ rewrite_calls(kTanhV8F32SymbolName, GenerateVF32Tanh, /*vector_width=*/8);
- auto* log_v4f32 =
- EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName,
- /*vector_width=*/4, enable_fast_math);
- auto* log_v8f32 =
- EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName,
- /*vector_width=*/8, enable_fast_math);
+ rewrite_calls("expf", GenerateVF32Exp, /*vector_width=*/1);
+ rewrite_calls("llvm.exp.f32", GenerateVF32Exp, /*vector_width=*/1);
+ rewrite_calls(kExpV4F32SymbolName, GenerateVF32Exp, /*vector_width=*/4);
+ rewrite_calls(kExpV8F32SymbolName, GenerateVF32Exp, /*vector_width=*/8);
- // Gather all the call sites, force inline them and then delete the vector
- // function bodies.
- //
- // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases?
-
- std::vector<llvm::CallInst*> calls_to_inline;
- for (auto* function :
- {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
- if (function != nullptr) {
- for (auto* user : function->users()) {
- calls_to_inline.push_back(llvm::cast<llvm::CallInst>(user));
- }
- }
- }
-
- for (auto* call_to_inline : calls_to_inline) {
- llvm::InlineFunctionInfo inline_function_info;
- CHECK(llvm::InlineFunction(call_to_inline, inline_function_info));
- }
-
- for (auto* function :
- {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) {
- if (function != nullptr) {
- function->eraseFromParent();
- }
- }
+ rewrite_calls("logf", GenerateVF32Log, /*vector_width=*/1);
+ rewrite_calls("llvm.log.f32", GenerateVF32Log, /*vector_width=*/1);
+ rewrite_calls(kLogV4F32SymbolName, GenerateVF32Log, /*vector_width=*/4);
+ rewrite_calls(kLogV8F32SymbolName, GenerateVF32Log, /*vector_width=*/8);
}
} // namespace runtime
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
index 3934c03..762ee67 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_key_value_sort_test.cc
@@ -26,10 +26,16 @@
const string hlo_text = R"(
HloModule KeyValueSort
+compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY main {
a = f32[10] parameter(0)
- ROOT result = f32[10] sort(f32[10] a), dimensions={0}
+ ROOT result = f32[10] sort(f32[10] a), dimensions={0}, to_apply=compare
}
)";
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index a7702c2..030bd41 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -75,8 +75,9 @@
// the buffers in the HLO module. We'll inspect these loads to ensure that
// they have the expected alias information.
llvm::Module ir_module("test", context);
- llvm::Function* func = llvm::cast<llvm::Function>(
- ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context)));
+ llvm::Function* func = llvm::dyn_cast<llvm::Function>(
+ ir_module.getOrInsertFunction("test_fn", llvm::Type::getVoidTy(context))
+ .getCallee());
llvm::BasicBlock* bb = llvm::BasicBlock::Create(context, "body", func);
llvm::IRBuilder<> b(bb);
auto* zero = llvm::ConstantInt::get(llvm::Type::getInt32Ty(context), 0);
diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
index 5690d2b..c444fd7 100644
--- a/tensorflow/compiler/xla/service/cpu/vector_support_library.h
+++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h
@@ -114,6 +114,9 @@
// raison d'etre) less cluttered.
llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs);
+ llvm::Value* FCmpEQMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
+ return FCmpEQMask(lhs, GetConstantFloat(lhs->getType(), rhs));
+ }
llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs);
llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) {
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 727e0bf..e868dc6 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -440,14 +440,16 @@
{operand_value},
{operand_value->getType()}, b_);
case HloOpcode::kSign: {
- // TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = FCmpOEQ(operand_value, zero);
- auto olt = FCmpOLT(operand_value, zero);
- return Select(oeq, zero,
- Select(olt, llvm::ConstantFP::get(type, -1.0),
- llvm::ConstantFP::get(type, 1.0)));
+ auto ne0_i1 = FCmpONE(operand_value, zero);
+ auto ne0_float = UIToFP(ne0_i1, type);
+ llvm::Value* result = llvm_ir::EmitCallToIntrinsic(
+ llvm::Intrinsic::copysign, {ne0_float, operand_value},
+ {operand_value->getType()}, b_);
+ auto is_nan = FCmpUNO(operand_value, operand_value);
+ result = Select(is_nan, operand_value, result);
+ return result;
}
case HloOpcode::kIsFinite: {
// abs(x) o!= inf, this works because the comparison returns false if
@@ -855,6 +857,9 @@
return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}
+// TODO(b/123355973): We have an implementation of erfinv in math.cc. We
+// shouldn't have two implementations, especially since this one isn't testable
+// (it's only observable via a normally-distributed RNG).
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
llvm::Value* x) {
if (prim_type != F16 && prim_type != F32 && prim_type != F64) {
@@ -1767,18 +1772,10 @@
auto index_typed_const = [&](uint64 c) -> llvm::Constant* {
return llvm::ConstantInt::get(index_type, c);
};
- // TODO(b/118437727): Remove the R1 path.
- llvm::Value* start_index_value;
- if (hlo->operand(1)->shape().rank() == 1) {
- llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
- TF_ASSIGN_OR_RETURN(start_index_value,
- operand_to_generator.at(hlo->operand(1))(dim_index));
- } else {
- llvm_ir::IrArray::Index zero_index(index_type);
- TF_ASSIGN_OR_RETURN(
- start_index_value,
- operand_to_generator.at(hlo->operand(1 + i))(zero_index));
- }
+ llvm_ir::IrArray::Index zero_index(index_type);
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * start_index_value,
+ operand_to_generator.at(hlo->operand(1 + i))(zero_index));
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
@@ -1924,18 +1921,10 @@
return llvm::ConstantInt::get(index_type, c);
};
- llvm::Value* start_index_value;
- // TODO(b/118437727): Remove the R1 path.
- if (hlo->operand(2)->shape().rank() == 1) {
- llvm_ir::IrArray::Index dim_index(1, index_typed_const(i));
- TF_ASSIGN_OR_RETURN(start_index_value,
- operand_to_generator.at(hlo->operand(2))(dim_index));
- } else {
- llvm_ir::IrArray::Index zero_index(index_type);
- TF_ASSIGN_OR_RETURN(
- start_index_value,
- operand_to_generator.at(hlo->operand(2 + i))(zero_index));
- }
+ llvm_ir::IrArray::Index zero_index(index_type);
+ TF_ASSIGN_OR_RETURN(
+ llvm::Value * start_index_value,
+ operand_to_generator.at(hlo->operand(2 + i))(zero_index));
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 10b8c01..1518d83 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -26,7 +26,6 @@
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/env.h"
-
namespace xla {
StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
@@ -173,11 +172,13 @@
}
filename = SanitizeFileName(std::move(filename));
string file_path = tensorflow::io::JoinPath(directory_path, filename);
- string result;
- TF_RET_CHECK(
- tensorflow::SerializeToStringDeterministic(hlo_session, &result));
- return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path,
- result);
+ const size_t size = hlo_session.ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ TF_RET_CHECK(tensorflow::SerializeToBufferDeterministic(
+ hlo_session, serialized.get(), size));
+ return tensorflow::WriteStringToFile(
+ tensorflow::Env::Default(), file_path,
+ absl::string_view(serialized.get(), size));
}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index 7d450f4..cb43c27 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -26,7 +26,6 @@
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 2ab754a..ffd4214 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -308,9 +308,11 @@
false); // No variadic arguments.
// Declares the callee if it is not declared already.
- llvm::Function* callee = llvm::cast<llvm::Function>(
- b_->GetInsertBlock()->getModule()->getOrInsertFunction(
- llvm_ir::AsStringRef(callee_name), callee_type));
+ llvm::Function* callee = llvm::dyn_cast<llvm::Function>(
+ b_->GetInsertBlock()
+ ->getModule()
+ ->getOrInsertFunction(llvm_ir::AsStringRef(callee_name), callee_type)
+ .getCallee());
for (auto attribute : attributes) {
callee->addFnAttr(attribute);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
index 4268fb2..4765f67 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc
@@ -17,7 +17,6 @@
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 29756d2..391029e 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -368,12 +368,21 @@
TEST_F(LayoutAssignmentTest, SortLayout) {
const char* hlo_text = R"(
HloModule SortLayout
+
+ compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = f32[] parameter(2)
+ p.1.rhs = f32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+ }
+
ENTRY sort {
keys = f32[3,2]{0,1} constant({{0,1},{0,1},{0,1}})
values = f32[2,3]{1,0} parameter(0)
transpose = f32[3,2]{1,0} transpose(values), dimensions={1,0}
ROOT sort = (f32[3,2]{1,0}, f32[3,2]{1,0}) sort(keys, transpose),
- dimensions={1}
+ dimensions={1}, to_apply=compare
})";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
index 8c6a691..e593f53 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc
@@ -30,7 +30,6 @@
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 82bdd67..3ed6553 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -20,7 +20,6 @@
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
-#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 294a454..4fab959 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -38,7 +38,6 @@
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
-#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 1f4f176..4bead3e 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -118,6 +118,9 @@
const HloModuleConfig& hlo_module_config) {
std::vector<string> potential_cuda_roots = tensorflow::CandidateCudaRoots();
+ // "." is our last resort, even though it probably won't work.
+ potential_cuda_roots.push_back(".");
+
// CUDA location explicitly specified by user via --xla_gpu_cuda_data_dir has
// highest priority.
string xla_gpu_cuda_data_dir =
@@ -129,9 +132,23 @@
return potential_cuda_roots;
}
+void PrintCantFindCudaMessage(absl::string_view msg,
+ const HloModuleConfig& hlo_module_config) {
+ LOG(WARNING) << msg;
+ LOG(WARNING) << "Searched in the following directories:";
+ for (const auto& dir : GetCudaRootCandidates(hlo_module_config)) {
+ LOG(WARNING) << " " << dir;
+ }
+ LOG(WARNING)
+ << "You can choose the search directory by setting xla_gpu_cuda_data_dir "
+ "in HloModule's DebugOptions. For most apps, setting the environment "
+ "variable XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda will work.";
+}
+
// Returns the directory containing nvvm libdevice files.
string GetLibdeviceDir(const HloModuleConfig& hlo_module_config) {
- for (const string& cuda_root : GetCudaRootCandidates(hlo_module_config)) {
+ const auto& candidate_dirs = GetCudaRootCandidates(hlo_module_config);
+ for (const string& cuda_root : candidate_dirs) {
string libdevice_dir =
tensorflow::io::JoinPath(cuda_root, "nvvm", "libdevice");
VLOG(2) << "Looking for libdevice at " << libdevice_dir;
@@ -140,8 +157,14 @@
return libdevice_dir;
}
}
- LOG(WARNING) << "Unable to find libdevice dir. Using '.'";
- // Last resort: maybe in the current folder.
+ PrintCantFindCudaMessage(
+ "Can't find directory containing CUDA libevice. This may result in "
+ "compilation or runtime failures, if the program we try to run uses "
+ "routines from libdevice.",
+ hlo_module_config);
+
+ // GetCudaRotCandidates always inclues ".", but but if everything fails, we
+ // return it anyway. Better than returning the empty string.
return ".";
}
@@ -843,10 +866,11 @@
log_warning = !warning_done.exchange(true);
}
if (log_warning) {
- LOG(WARNING)
- << "Failed to compile ptx to cubin. Will attempt to let "
- "GPU driver compile the ptx. "
- << maybe_cubin.status();
+ PrintCantFindCudaMessage(
+ "Can't find ptxas binary. Will back to the GPU driver "
+ "for PTX -> sass compilation. This is OK so long as you don't "
+ "see a warning below about an out-of-date driver version.",
+ hlo_module_config);
}
// We're going to use the driver to JIT our PTX->SASS, so warn if
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index 8154d75..cb01264 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -25,7 +25,6 @@
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
index a1ed849..d33e9cf 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_copy_test.cc
@@ -24,7 +24,6 @@
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
index f91a22d..06b06a5 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/infeed_test.cc
@@ -25,7 +25,6 @@
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
index c552c29..bbbcc2d 100644
--- a/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/variadic_op_splitter.cc
@@ -23,7 +23,6 @@
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.cc b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
index d56f673..5d3e11f 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.cc
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.cc
@@ -268,6 +268,41 @@
select_shape, HloOpcode::kSelect, pred, on_true, on_false));
}
+StatusOr<HloInstruction*> MakeSortHlo(
+ const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
+ int64 dimension_to_sort, HloComputation::Builder* builder,
+ HloModule* module) {
+ CHECK(!operands.empty()) << "Sort Hlo requires at least one operand.";
+ HloComputation* compare_computation;
+ {
+ auto b = HloComputation::Builder("Sort.Compare");
+ Shape key_scalar_shape =
+ ShapeUtil::MakeShape(operands[0]->shape().element_type(), {});
+ auto lhs = b.AddInstruction(
+ HloInstruction::CreateParameter(0, key_scalar_shape, "p.0.lhs"));
+ auto rhs = b.AddInstruction(
+ HloInstruction::CreateParameter(1, key_scalar_shape, "p.0.rhs"));
+ int parameter_count = 2;
+ for (const auto* operand : operands.subspan(1)) {
+ Shape scalar_shape =
+ ShapeUtil::MakeShape(operand->shape().element_type(), {});
+ b.AddInstruction(HloInstruction::CreateParameter(
+ parameter_count, scalar_shape,
+ StrCat("p.", parameter_count / 2, ".lhs")));
+ ++parameter_count;
+ b.AddInstruction(HloInstruction::CreateParameter(
+ parameter_count, scalar_shape,
+ StrCat("p.", parameter_count / 2, ".rhs")));
+ ++parameter_count;
+ }
+ b.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, lhs, rhs));
+ compare_computation = module->AddEmbeddedComputation(b.Build());
+ }
+ return builder->AddInstruction(HloInstruction::CreateSort(
+ sort_shape, dimension_to_sort, operands, compare_computation));
+}
+
StatusOr<HloInstruction*> CollapseFirstNDims(HloInstruction* operand, int64 n) {
CHECK_GT(n, 0);
diff --git a/tensorflow/compiler/xla/service/hlo_creation_utils.h b/tensorflow/compiler/xla/service/hlo_creation_utils.h
index 1c3174e..80f58f6 100644
--- a/tensorflow/compiler/xla/service/hlo_creation_utils.h
+++ b/tensorflow/compiler/xla/service/hlo_creation_utils.h
@@ -123,6 +123,18 @@
HloInstruction* on_true,
HloInstruction* on_false);
+// Creates a Sort HLO instruction and adds it to the computation containing the
+// operands. All operands must be in the same computation. Also creates a
+// default compare sub-computation which sorts the first operand into ascending
+// order.
+// Note that this default compare sub-computation does not have special handling
+// for floating point values and thus can result in undefined behavior in the
+// presence of NaN values.
+StatusOr<HloInstruction*> MakeSortHlo(
+ const Shape& sort_shape, absl::Span<HloInstruction* const> operands,
+ int64 dimension_to_sort, HloComputation::Builder* builder,
+ HloModule* module);
+
// Creates an R1 Constant HLO instruction of the given PrimitiveType with the
// given values and adds it to the given computation.
template <typename NativeT>
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index e602107..849cac2 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -33,7 +33,6 @@
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/hash/hash.h"
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 4a7c496..e3059e0 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -2356,14 +2357,16 @@
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());
+ module_ = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- auto sort =
- builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort, MakeSortHlo(keys_shape, {keys}, -1, &builder, module_.get()));
- BuildModuleAndRunAnalysis(builder.Build());
+ computation_ = module_->AddEntryComputation(builder.Build());
+ RunAnalysis();
EXPECT_TRUE(
dataflow_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
@@ -2371,6 +2374,7 @@
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
auto builder = HloComputation::Builder(TestName());
+ module_ = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
Shape values_shape = ShapeUtil::MakeShape(F32, {8});
@@ -2378,11 +2382,13 @@
HloInstruction::CreateParameter(0, keys_shape, "keys"));
auto values = builder.AddInstruction(
HloInstruction::CreateParameter(1, values_shape, "values"));
- auto sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
- {values}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort,
+ MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
+ {keys, values}, 0, &builder, module_.get()));
- BuildModuleAndRunAnalysis(builder.Build());
+ computation_ = module_->AddEntryComputation(builder.Build());
+ RunAnalysis();
// The buffer for the keys can be shared with the first tuple entry.
EXPECT_TRUE(
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index 9b0f2b2..bff4677 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -127,6 +127,7 @@
// These are ops where it does not make sense to convert them.
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
+ opcode == HloOpcode::kBitcastConvert ||
opcode == HloOpcode::kGetTupleElement ||
opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
continue;
@@ -148,7 +149,11 @@
opcode == HloOpcode::kConditional) {
continue;
}
- TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
+ // TODO(b/122298745): Once we don't ignore called computations anymore,
+ // add kSort to the if statement above.
+ if (opcode != HloOpcode::kSort) {
+ TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
+ }
bool nullary = hlo->operands().empty();
bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
index 5b63378..4171f73 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc
@@ -176,5 +176,19 @@
EXPECT_THAT(rng1->control_predecessors(), ElementsAre(rng0));
}
+TEST_F(HloElementTypeConverterTest, BitcastConvertIsUnmodified) {
+ const string& hlo_string = R"(
+ HloModule test
+
+ ENTRY test {
+ p = bf16[] parameter(0)
+ ROOT c = u16[] bitcast-convert(p)
+ })";
+ auto module = ParseAndReturnVerifiedModule(hlo_string).ValueOrDie();
+ HloElementTypeConverter converter(BF16, F32);
+ TF_ASSERT_OK_AND_ASSIGN(bool converted, RunHloPass(&converter, module.get()));
+ EXPECT_FALSE(converted);
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index 29e4143..56a1b6f 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1361,46 +1361,81 @@
}
namespace {
+template <typename NativeT>
+Literal ExtractLiteralFromIndexPositions(const Literal& from,
+ absl::Span<int64 const> indices,
+ bool extract_as_scalar) {
+ if (extract_as_scalar) {
+ return LiteralUtil::CreateR0<NativeT>(from.Get<NativeT>({indices[0]}));
+ }
+ // We use a InlinedVector here because we need to convert it to an
+ // absl::Span later, and this would not work with std::vector<bool>.
+ absl::InlinedVector<NativeT, 10> values;
+ for (int64 index : indices) {
+ values.push_back(from.Get<NativeT>({index}));
+ }
+ return LiteralUtil::CreateR1<NativeT>(values);
+}
+
StatusOr<Literal> ExtractFromIndexPositions(const Literal& from,
- absl::Span<int64 const> indices) {
+ absl::Span<int64 const> indices,
+ bool extract_as_scalar = false) {
+ if (extract_as_scalar) {
+ CHECK_EQ(indices.size(), 1);
+ }
PrimitiveType type = from.shape().element_type();
switch (type) {
case PRED: {
- // We use a InlinedVector here because we need to convert it to an
- // absl::Span later, and this would not work with std::vector<bool>.
- absl::InlinedVector<bool, 10> values;
- for (int64 index : indices) {
- values.push_back(from.Get<bool>({index}));
- }
- return LiteralUtil::CreateR1<bool>(values);
+ return ExtractLiteralFromIndexPositions<bool>(from, indices,
+ extract_as_scalar);
}
- case F32: {
- std::vector<float> values;
- for (int64 index : indices) {
- values.push_back(from.Get<float>({index}));
- }
- return LiteralUtil::CreateR1<float>(values);
+ case U8: {
+ return ExtractLiteralFromIndexPositions<uint8>(from, indices,
+ extract_as_scalar);
}
- case U32: {
- std::vector<uint32> values;
- for (int64 index : indices) {
- values.push_back(from.Get<uint32>({index}));
- }
- return LiteralUtil::CreateR1<uint32>(values);
- }
- case S32: {
- std::vector<int32> values;
- for (int64 index : indices) {
- values.push_back(from.Get<int32>({index}));
- }
- return LiteralUtil::CreateR1<int32>(values);
+ case S8: {
+ return ExtractLiteralFromIndexPositions<int8>(from, indices,
+ extract_as_scalar);
}
case BF16: {
- std::vector<bfloat16> values;
- for (int64 index : indices) {
- values.push_back(from.Get<bfloat16>({index}));
- }
- return LiteralUtil::CreateR1<bfloat16>(values);
+ return ExtractLiteralFromIndexPositions<bfloat16>(from, indices,
+ extract_as_scalar);
+ }
+ case F16: {
+ return ExtractLiteralFromIndexPositions<Eigen::half>(from, indices,
+ extract_as_scalar);
+ }
+ case U16: {
+ return ExtractLiteralFromIndexPositions<uint16>(from, indices,
+ extract_as_scalar);
+ }
+ case S16: {
+ return ExtractLiteralFromIndexPositions<int16>(from, indices,
+ extract_as_scalar);
+ }
+ case F32: {
+ return ExtractLiteralFromIndexPositions<float>(from, indices,
+ extract_as_scalar);
+ }
+ case U32: {
+ return ExtractLiteralFromIndexPositions<uint32>(from, indices,
+ extract_as_scalar);
+ }
+ case S32: {
+ return ExtractLiteralFromIndexPositions<int32>(from, indices,
+ extract_as_scalar);
+ }
+ case F64: {
+ return ExtractLiteralFromIndexPositions<double>(from, indices,
+ extract_as_scalar);
+ }
+ case U64: {
+ return ExtractLiteralFromIndexPositions<uint64>(from, indices,
+ extract_as_scalar);
+ }
+ case S64: {
+ return ExtractLiteralFromIndexPositions<int64>(from, indices,
+ extract_as_scalar);
}
default:
return InvalidArgument("Unsupported type for Sort: %s",
@@ -1410,108 +1445,151 @@
} // namespace
Status HloEvaluator::HandleSort(HloInstruction* sort) {
- if (!sort->shape().IsTuple()) {
- return DefaultAction(sort);
- } else {
- TF_RET_CHECK(sort->operand_count() >= 2) << "Expected key-value sort";
- for (int64 i = 1; i < sort->operand_count(); ++i) {
- TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
- sort->operand(i)->shape()))
- << "All Sort operands must have the same dimensions";
- }
+ TF_RET_CHECK(sort->operand_count() >= 1)
+ << "Expected at least 1 operand for sort";
+ for (int64 i = 1; i < sort->operand_count(); ++i) {
+ TF_RET_CHECK(ShapeUtil::SameDimensions(sort->operand(0)->shape(),
+ sort->operand(i)->shape()))
+ << "All Sort operands must have the same dimensions";
+ }
- if (VLOG_IS_ON(3)) {
- for (int64 i = 0; i < sort->operand_count(); ++i) {
- VLOG(3) << "HandleSort operand " << i << " literal: "
- << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
- }
- }
- Shape key_shape = sort->operand(0)->shape();
- auto rank = key_shape.rank();
- PrimitiveType keys_type = key_shape.element_type();
- if (keys_type != F32 && keys_type != U32 && keys_type != S32 &&
- keys_type != BF16) {
- return InvalidArgument("Unsupported type for Sort: %s",
- PrimitiveType_Name(keys_type));
- }
- std::vector<Literal> result_literals;
- result_literals.reserve(sort->operand_count());
+ if (VLOG_IS_ON(3)) {
for (int64 i = 0; i < sort->operand_count(); ++i) {
- result_literals.emplace_back(sort->operand(i)->shape());
+ VLOG(3) << "HandleSort operand " << i << " literal: "
+ << GetEvaluatedLiteralFor(sort->operand(i)).ToString();
}
- std::vector<int64> zero_base(rank, 0);
- std::vector<int64> increment(rank, 1);
- int64 sort_dim = sort->dimensions(0);
- int64 sort_dim_elements = key_shape.dimensions(sort_dim);
- increment[sort_dim] = sort_dim_elements;
- // Iterate through each dimension except 'sort_dim'.
- TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
- [&](absl::Span<const int64> indices) -> StatusOr<bool> {
- // Extract a slice from each operand literal that corresponds to
- // exactly the row in dimension 'sort_dim'.
- std::vector<int64> limit_indices(indices.begin(), indices.end());
- absl::c_for_each(limit_indices, [](int64& index) { ++index; });
- limit_indices[sort_dim] = sort_dim_elements;
- std::vector<Literal> literals_to_sort;
- literals_to_sort.reserve(sort->operand_count());
- for (int64 i = 0; i < sort->operand_count(); ++i) {
- TF_ASSIGN_OR_RETURN(auto literal_to_sort,
- GetEvaluatedLiteralFor(sort->operand(i))
- .Slice(indices, limit_indices)
- .Reshape({sort_dim_elements}));
- literals_to_sort.push_back(std::move(literal_to_sort));
- }
- std::vector<int64> indices_to_sort(sort_dim_elements);
- std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
- std::stable_sort(
- indices_to_sort.begin(), indices_to_sort.end(),
- [keys_type, &literals_to_sort](int64 a, int64 b) {
- switch (keys_type) {
- case F32: {
- auto key_lhs = literals_to_sort[0].Get<float>({a});
- auto key_rhs = literals_to_sort[0].Get<float>({b});
- return SafeLess(key_lhs, key_rhs);
- }
- case U32: {
- auto key_lhs = literals_to_sort[0].Get<uint32>({a});
- auto key_rhs = literals_to_sort[0].Get<uint32>({b});
- return SafeLess(key_lhs, key_rhs);
- }
- case S32: {
- auto key_lhs = literals_to_sort[0].Get<int32>({a});
- auto key_rhs = literals_to_sort[0].Get<int32>({b});
- return SafeLess(key_lhs, key_rhs);
- }
- case BF16: {
- auto key_lhs = literals_to_sort[0].Get<bfloat16>({a});
- auto key_rhs = literals_to_sort[0].Get<bfloat16>({b});
- return SafeLess(key_lhs, key_rhs);
- }
- default:
- // We should never reach here, because we checked earlier
- // that 'key_type' is one of the cases above.
- LOG(FATAL) << "Invalid key type in Sort: %s",
- PrimitiveType_Name(keys_type);
- return false;
+ }
+ Shape key_shape = sort->operand(0)->shape();
+ auto rank = key_shape.rank();
+ PrimitiveType keys_type = key_shape.element_type();
+ if (keys_type != F64 && keys_type != U64 && keys_type != S64 &&
+ keys_type != F32 && keys_type != U32 && keys_type != S32 &&
+ keys_type != BF16 && keys_type != F16 && keys_type != U16 &&
+ keys_type != S16 && keys_type != U8 && keys_type != S8) {
+ return InvalidArgument("Unsupported type for Sort: %s",
+ PrimitiveType_Name(keys_type));
+ }
+ std::vector<Literal> result_literals;
+ result_literals.reserve(sort->operand_count());
+ for (int64 i = 0; i < sort->operand_count(); ++i) {
+ result_literals.emplace_back(sort->operand(i)->shape());
+ }
+ std::vector<int64> zero_base(rank, 0);
+ std::vector<int64> increment(rank, 1);
+ int64 sort_dim = sort->dimensions(0);
+ int64 sort_dim_elements = key_shape.dimensions(sort_dim);
+ increment[sort_dim] = sort_dim_elements;
+ // Iterate through each dimension except 'sort_dim'.
+ TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
+ key_shape, zero_base, AsInt64Slice(key_shape.dimensions()), increment,
+ [&](absl::Span<const int64> indices) -> StatusOr<bool> {
+ // Extract a slice from each operand literal that corresponds to
+ // exactly the row in dimension 'sort_dim'.
+ std::vector<int64> limit_indices(indices.begin(), indices.end());
+ absl::c_for_each(limit_indices, [](int64& index) { ++index; });
+ limit_indices[sort_dim] = sort_dim_elements;
+ std::vector<Literal> literals_to_sort;
+ literals_to_sort.reserve(sort->operand_count());
+ for (int64 i = 0; i < sort->operand_count(); ++i) {
+ TF_ASSIGN_OR_RETURN(auto literal_to_sort,
+ GetEvaluatedLiteralFor(sort->operand(i))
+ .Slice(indices, limit_indices)
+ .Reshape({sort_dim_elements}));
+ literals_to_sort.push_back(std::move(literal_to_sort));
+ }
+ std::vector<int64> indices_to_sort(sort_dim_elements);
+ std::iota(indices_to_sort.begin(), indices_to_sort.end(), 0);
+ std::stable_sort(
+ indices_to_sort.begin(), indices_to_sort.end(),
+ [keys_type, &literals_to_sort](int64 a, int64 b) {
+ switch (keys_type) {
+ case F64: {
+ auto key_lhs = literals_to_sort[0].Get<double>({a});
+ auto key_rhs = literals_to_sort[0].Get<double>({b});
+ return SafeLess(key_lhs, key_rhs);
}
- });
- std::vector<int64> slice_dimensions(rank, 1);
- slice_dimensions[sort_dim] = sort_dim_elements;
- std::vector<int64> start_indices(rank, 0);
- for (int64 i = 0; i < sort->operand_count(); ++i) {
- TF_ASSIGN_OR_RETURN(Literal sorted_literal,
- ExtractFromIndexPositions(literals_to_sort[i],
- indices_to_sort));
- TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
- sorted_literal.Reshape(slice_dimensions));
- TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
- sorted_literal_reshaped, start_indices, indices,
- slice_dimensions));
- }
- return true;
- }));
+ case U64: {
+ auto key_lhs = literals_to_sort[0].Get<uint64>({a});
+ auto key_rhs = literals_to_sort[0].Get<uint64>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case S64: {
+ auto key_lhs = literals_to_sort[0].Get<int64>({a});
+ auto key_rhs = literals_to_sort[0].Get<int64>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case F32: {
+ auto key_lhs = literals_to_sort[0].Get<float>({a});
+ auto key_rhs = literals_to_sort[0].Get<float>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case U32: {
+ auto key_lhs = literals_to_sort[0].Get<uint32>({a});
+ auto key_rhs = literals_to_sort[0].Get<uint32>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case S32: {
+ auto key_lhs = literals_to_sort[0].Get<int32>({a});
+ auto key_rhs = literals_to_sort[0].Get<int32>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case BF16: {
+ auto key_lhs = literals_to_sort[0].Get<bfloat16>({a});
+ auto key_rhs = literals_to_sort[0].Get<bfloat16>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case F16: {
+ auto key_lhs = literals_to_sort[0].Get<Eigen::half>({a});
+ auto key_rhs = literals_to_sort[0].Get<Eigen::half>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case U16: {
+ auto key_lhs = literals_to_sort[0].Get<uint16>({a});
+ auto key_rhs = literals_to_sort[0].Get<uint16>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case S16: {
+ auto key_lhs = literals_to_sort[0].Get<int16>({a});
+ auto key_rhs = literals_to_sort[0].Get<int16>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case U8: {
+ auto key_lhs = literals_to_sort[0].Get<uint8>({a});
+ auto key_rhs = literals_to_sort[0].Get<uint8>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ case S8: {
+ auto key_lhs = literals_to_sort[0].Get<int8>({a});
+ auto key_rhs = literals_to_sort[0].Get<int8>({b});
+ return SafeLess(key_lhs, key_rhs);
+ }
+ default:
+ // We should never reach here, because we checked earlier
+ // that 'key_type' is one of the cases above.
+ LOG(FATAL) << "Invalid key type in Sort: %s",
+ PrimitiveType_Name(keys_type);
+ return false;
+ }
+ });
+ std::vector<int64> slice_dimensions(rank, 1);
+ slice_dimensions[sort_dim] = sort_dim_elements;
+ std::vector<int64> start_indices(rank, 0);
+ for (int64 i = 0; i < sort->operand_count(); ++i) {
+ TF_ASSIGN_OR_RETURN(
+ Literal sorted_literal,
+ ExtractFromIndexPositions(literals_to_sort[i], indices_to_sort));
+ TF_ASSIGN_OR_RETURN(auto sorted_literal_reshaped,
+ sorted_literal.Reshape(slice_dimensions));
+ TF_RETURN_IF_ERROR(result_literals[i].CopySliceFrom(
+ sorted_literal_reshaped, start_indices, indices,
+ slice_dimensions));
+ }
+ return true;
+ }));
+ if (sort->operand_count() == 1) {
+ evaluated_[sort] = std::move(result_literals[0]);
+ } else {
std::vector<const Literal*> literal_ptrs;
absl::c_transform(result_literals, std::back_inserter(literal_ptrs),
[](const Literal& literal) { return &literal; });
@@ -1520,8 +1598,8 @@
VLOG(3) << "HandleSort result_tuple: " << result_tuple.ToString();
evaluated_[sort] = std::move(result_tuple);
- return Status::OK();
}
+ return Status::OK();
}
Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
index 648c7d0..652042e 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_typed_visitor.h
@@ -462,9 +462,9 @@
return HandleNegate<ReturnT>(negate);
}
- template <
- typename NativeT,
- typename std::enable_if<!is_complex_t<NativeT>::value>::type* = nullptr>
+ template <typename NativeT,
+ typename std::enable_if<std::is_integral<NativeT>::value>::type* =
+ nullptr>
Status HandleSign(HloInstruction* sign) {
TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
@@ -474,6 +474,23 @@
return Status::OK();
}
+ template <typename NativeT,
+ typename std::enable_if<
+ std::is_same<NativeT, bfloat16>::value ||
+ std::is_same<NativeT, Eigen::half>::value ||
+ std::is_floating_point<NativeT>::value>::type* = nullptr>
+ Status HandleSign(HloInstruction* sign) {
+ TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign],
+ ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) {
+ return std::isnan(elem_operand)
+ ? elem_operand
+ : std::copysign(
+ elem_operand != ElementwiseT(0),
+ elem_operand);
+ }));
+ return Status::OK();
+ }
+
template <
typename NativeT,
typename std::enable_if<is_complex_t<NativeT>::value>::type* = nullptr>
@@ -1662,73 +1679,8 @@
return Status::OK();
}
- template <typename NativeT,
- typename std::enable_if<
- !is_complex_t<NativeT>::value &&
- !std::is_same<NativeT, bool>::value>::type* = nullptr>
- Status HandleSort(HloInstruction* sort) {
- auto keys = sort->operand(0);
- TF_RET_CHECK(sort->operand_count() == 1)
- << "Typed visitor does not support key-value sort";
-
- const Literal& keys_literal = parent_->GetEvaluatedLiteralFor(keys);
- int64 sort_dim = sort->dimensions(0);
- int64 sort_dim_elements = keys->shape().dimensions(sort_dim);
- int64 rank = keys->shape().rank();
- if (rank == 0) {
- // Nothing to sort.
- parent_->evaluated_[sort] = keys_literal.Clone();
- return Status::OK();
- }
- Literal result_literal(keys_literal.shape());
- std::vector<int64> zero_base(rank, 0);
- std::vector<int64> increment(rank, 1);
- increment[sort_dim] = sort_dim_elements;
- // Iterate through each dimension except 'sort_dim'.
- TF_RETURN_IF_ERROR(ShapeUtil::ForEachIndexWithStatus(
- keys->shape(), zero_base, AsInt64Slice(keys->shape().dimensions()),
- increment, [&](absl::Span<const int64> indices) -> StatusOr<bool> {
- // Extract a slice from the literal that corresponds to exactly the
- // row in dimension 'sort_dim'.
- std::vector<int64> limit_indices(indices.begin(), indices.end());
- absl::c_for_each(limit_indices, [](int64& index) { ++index; });
- limit_indices[sort_dim] = sort_dim_elements;
- TF_ASSIGN_OR_RETURN(auto row_to_sort,
- keys_literal.Slice(indices, limit_indices)
- .Reshape({sort_dim_elements}));
- const auto& row_data = row_to_sort.data<NativeT>();
-
- std::vector<NativeT> result_data(row_data.begin(), row_data.end());
- std::stable_sort(result_data.begin(), result_data.end(),
- [](const NativeT& a, const NativeT& b) {
- return SafeLess<NativeT>(a, b);
- });
- Literal sorted_row(ShapeUtil::MakeShape(keys->shape().element_type(),
- {sort_dim_elements}));
- sorted_row.PopulateR1(absl::Span<const NativeT>(result_data));
- std::vector<int64> slice_dimensions(rank, 1);
- slice_dimensions[sort_dim] = sort_dim_elements;
- TF_ASSIGN_OR_RETURN(auto sorted_row_reshaped,
- sorted_row.Reshape(slice_dimensions));
- std::vector<int64> start_indices(rank, 0);
- TF_RETURN_IF_ERROR(result_literal.CopySliceFrom(
- sorted_row_reshaped, start_indices, indices, slice_dimensions));
- return true;
- }));
- parent_->evaluated_[sort] = std::move(result_literal);
- return Status::OK();
- }
-
- template <typename NativeT,
- typename std::enable_if<is_complex_t<NativeT>::value ||
- std::is_same<NativeT, bool>::value>::type* =
- nullptr>
- Status HandleSort(HloInstruction* sort) {
- return UnsupportedTypeError(sort);
- }
-
Status HandleSort(HloInstruction* sort) override {
- return HandleSort<ReturnT>(sort);
+ return UnsupportedTypeError(sort);
}
Status HandleReduce(HloInstruction* hlo) override {
diff --git a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
index a46a107..265bfdf 100644
--- a/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_input_output_alias_config_test.cc
@@ -29,7 +29,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1b677bc..fd9b8ea 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -203,9 +203,14 @@
<< "Sort instruction should have 1 dimension";
auto sort_operands = all_operands();
HloInstruction* keys = sort_operands[0];
- instruction = CreateSort(
- shape, proto.dimensions(0), keys,
- absl::Span<HloInstruction* const>(sort_operands).subspan(1));
+ if (proto.called_computation_ids_size() == 1) {
+ instruction = CreateSort(shape, proto.dimensions(0), all_operands(),
+ computations(0));
+ } else {
+ instruction = CreateSort(
+ shape, proto.dimensions(0), keys,
+ absl::Span<HloInstruction* const>(sort_operands).subspan(1));
+ }
break;
}
case HloOpcode::kTranspose:
@@ -1158,6 +1163,13 @@
return absl::make_unique<HloSortInstruction>(shape, dimension, keys, values);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort(
+ const Shape& shape, int64 dimension,
+ absl::Span<HloInstruction* const> operands, HloComputation* compare) {
+ return absl::make_unique<HloSortInstruction>(shape, dimension, operands,
+ compare);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) {
return absl::make_unique<HloFusionInstruction>(shape, fusion_kind,
@@ -1952,6 +1964,7 @@
case HloOpcode::kReduce:
case HloOpcode::kAllReduce:
case HloOpcode::kScatter:
+ case HloOpcode::kSort:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -1971,6 +1984,7 @@
case HloOpcode::kReduce:
case HloOpcode::kAllReduce:
case HloOpcode::kScatter:
+ case HloOpcode::kSort:
CHECK_EQ(called_computations_.size(), 1);
called_computations_[0] = computation;
break;
@@ -2243,9 +2257,14 @@
opcode() == HloOpcode::kReduceWindow ||
opcode() == HloOpcode::kReduce ||
opcode() == HloOpcode::kAllReduce ||
- opcode() == HloOpcode::kScatter) {
- extra.push_back(
- StrCat("to_apply=", PrintName(to_apply()->name(), options)));
+ opcode() == HloOpcode::kScatter ||
+ opcode() == HloOpcode::kSort) {
+ // TODO(b/122298745): Remove this check when Sort has a required
+ // sub-computation.
+ if (!called_computations().empty()) {
+ extra.push_back(
+ StrCat("to_apply=", PrintName(to_apply()->name(), options)));
+ }
} else if (!called_computations().empty()) {
extra.push_back(StrCat(
"calls=",
@@ -2280,8 +2299,13 @@
case HloOpcode::kReduce:
case HloOpcode::kAllReduce:
case HloOpcode::kScatter:
- extra.push_back(
- StrCat("to_apply=\n", to_apply()->ToString(new_options)));
+ case HloOpcode::kSort:
+ // TODO(b/122298745): Remove this check once sort has a required
+ // sub-computation.
+ if (to_apply() != nullptr) {
+ extra.push_back(
+ StrCat("to_apply=\n", to_apply()->ToString(new_options)));
+ }
break;
default:
if (!called_computations().empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index c11d29d..79c6c85 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -676,6 +676,15 @@
const Shape& shape, int64 dimension, HloInstruction* keys,
absl::Span<HloInstruction* const> values = {});
+ // Creates a n-ary sort op with a 'compare' computation which is used for
+ // comparisons in the sorting algorithm. 'compare' gets 2 * n parameters,
+ // where parameters 2 * i and 2 * i + 1 are the values of the i-th operand at
+ // specific index positions which should be compared, and should return a
+ // PRED.
+ static std::unique_ptr<HloInstruction> CreateSort(
+ const Shape& shape, int64 dimension,
+ absl::Span<HloInstruction* const> operands, HloComputation* compare);
+
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 785206b..8e40e54 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -619,6 +619,16 @@
}
}
+HloSortInstruction::HloSortInstruction(
+ const Shape& shape, int64 dimension,
+ absl::Span<HloInstruction* const> operands, HloComputation* compare)
+ : HloInstruction(HloOpcode::kSort, shape), dimensions_({dimension}) {
+ for (auto* value : operands) {
+ AppendOperand(value);
+ }
+ AppendComputation(compare);
+}
+
HloInstructionProto HloSortInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
for (int64 dimension : dimensions_) {
@@ -637,12 +647,25 @@
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloSortInstruction&>(other);
- return dimensions() == casted_other.dimensions();
+ if (dimensions() != casted_other.dimensions()) {
+ return false;
+ }
+ if (called_computations().empty()) {
+ return other.called_computations().empty();
+ }
+ if (other.called_computations().empty()) {
+ return false;
+ }
+ return eq_computations(to_apply(), other.to_apply());
}
std::unique_ptr<HloInstruction> HloSortInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
+ if (!called_computations().empty()) {
+ return absl::make_unique<HloSortInstruction>(shape, dimensions(0),
+ new_operands, to_apply());
+ }
HloInstruction* keys = new_operands[0];
return absl::make_unique<HloSortInstruction>(shape, dimensions(0), keys,
new_operands.subspan(1));
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 1b4a947..6b6157a 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -420,6 +420,9 @@
explicit HloSortInstruction(const Shape& shape, int64 dimension,
HloInstruction* keys,
absl::Span<HloInstruction* const> values = {});
+ explicit HloSortInstruction(const Shape& shape, int64 dimension,
+ absl::Span<HloInstruction* const> operands,
+ HloComputation* compare);
// Returns the dimension sizes or numbers associated with this instruction.
const std::vector<int64>& dimensions() const override { return dimensions_; }
int64 dimensions(int64 index) const override { return dimensions()[index]; }
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.cc b/tensorflow/compiler/xla/service/hlo_lexer.cc
index 7987608..c1a642d 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.cc
+++ b/tensorflow/compiler/xla/service/hlo_lexer.cc
@@ -212,6 +212,15 @@
// A lone '/' is an error.
return TokKind::kError;
}
+ case '.':
+ if (PeekCurrentChar() == '.') {
+ current_ptr_++;
+ if (PeekCurrentChar() == '.') {
+ current_ptr_++;
+ return TokKind::kDots;
+ }
+ }
+ return TokKind::kError;
case '"':
return LexString();
}
@@ -513,6 +522,8 @@
return "kInt";
case TokKind::kDecimal:
return "kDecimal";
+ case TokKind::kDots:
+ return "kDots";
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_lexer.h b/tensorflow/compiler/xla/service/hlo_lexer.h
index 94fac3c..16eed21 100644
--- a/tensorflow/compiler/xla/service/hlo_lexer.h
+++ b/tensorflow/compiler/xla/service/hlo_lexer.h
@@ -47,6 +47,7 @@
kRbrace, // { }
kLparen,
kRparen, // ( )
+ kDots, // ...
kArrow, // ->
kLeq, // <=
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index d7fa6b4..b3b24c1 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -28,6 +28,7 @@
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_lexer.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
@@ -852,13 +853,10 @@
break;
}
case HloOpcode::kReplicaId: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
+ if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {
return false;
}
- if (!operands.empty()) {
- return false;
- }
instruction = builder->AddInstruction(HloInstruction::CreateReplicaId());
break;
}
@@ -896,14 +894,23 @@
optional<std::vector<tensorflow::int64>> dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
&dimensions};
+ optional<HloComputation*> to_apply;
+ // TODO(b/122298745): Make this required.
+ attrs["to_apply"] = {/*required=*/false, AttrTy::kHloComputation,
+ &to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
dimensions->size() != 1) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateSort(
- shape, dimensions->at(0),
- /*keys=*/operands[0],
- /*values=*/absl::Span<HloInstruction* const>(operands).subspan(1)));
+ if (to_apply.has_value()) {
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0), operands, to_apply.value()));
+ } else {
+ instruction = builder->AddInstruction(HloInstruction::CreateSort(
+ shape, dimensions->at(0),
+ /*keys=*/operands[0],
+ /*values=*/absl::Span<HloInstruction* const>(operands).subspan(1)));
+ }
break;
}
case HloOpcode::kTuple: {
@@ -2150,6 +2157,16 @@
}
break;
}
+ case TokKind::kDots: {
+ if (nest_level != 1) {
+ return TokenError(absl::StrFormat(
+ "expects `...` at nest level 1, but sees it at nest level %d",
+ nest_level));
+ }
+ elems_seen_per_dim[0] = shape.dimensions(0);
+ lexer_.Lex();
+ break;
+ }
case TokKind::kComma:
// Skip.
lexer_.Lex();
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 6eee767..cd7effe 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1047,9 +1047,15 @@
"SortKey",
R"(HloModule sort
+compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
x = f32[1024]{0} parameter(0)
- ROOT sorted = f32[1024]{0} sort(x), dimensions={0}
+ ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
}
)"
@@ -1059,10 +1065,18 @@
"SortKeyValue",
R"(HloModule sort
+compare {
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
keys = f32[1024]{0} parameter(0)
values = s32[1024]{0} parameter(1)
- ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
@@ -1072,9 +1086,15 @@
"SortKeyR2",
R"(HloModule sort
+compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
x = f32[1024,16]{0,1} parameter(0)
- ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}
+ ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
}
)"
@@ -1084,10 +1104,18 @@
"SortKeyValueR2",
R"(HloModule sort
+compare {
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
keys = f32[1024,16]{0,1} parameter(0)
values = s32[1024,16]{0,1} parameter(1)
- ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}
+ ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
@@ -1097,12 +1125,24 @@
"SortManyValues",
R"(HloModule sort
+compare {
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ p.2.lhs = u32[] parameter(4)
+ p.2.rhs = u32[] parameter(5)
+ p.3.lhs = f32[] parameter(6)
+ p.3.rhs = f32[] parameter(7)
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
keys = f32[1024,16]{0,1} parameter(0)
values.0 = s32[1024,16]{0,1} parameter(1)
values.1 = u32[1024,16]{0,1} parameter(2)
values.2 = f32[1024,16]{0,1} parameter(3)
- ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}
+ ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
}
)"
@@ -1282,6 +1322,17 @@
)"
},
+// replica-id
+{
+"ReplicaId",
+R"(HloModule replica-id
+
+ENTRY Replica-id {
+ ROOT replica-id = u32[] replica-id()
+}
+
+)"
+},
// Iota
{
"Iota",
@@ -1309,10 +1360,18 @@
"ScheduledModule",
R"(HloModule scheduled_module, is_scheduled=true
+compare {
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ ROOT lhs = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY Sort {
keys = f32[1024]{0} parameter(0)
values = s32[1024]{0} parameter(1)
- ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}
+ ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
}
)"
@@ -1396,7 +1455,7 @@
protected:
// Expects "ToString(ParseHloString(string)) == string", that is, parses the
// string, asserts that it succeeded, stringifies the parsed module, and
- // checks that the it equals the original string.
+ // checks that it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
@@ -1719,6 +1778,19 @@
// printed as "300".
}
+TEST_F(HloParserTest, ShortConstant) {
+ const string original = R"(HloModule ShortCOnstant_module
+
+ENTRY %ShortConstant.v4 () -> f32[67,89] {
+ ROOT %constant.1 = f32[67,89]{1,0} constant({...})
+}
+
+)";
+ auto result = ParseHloString(original);
+ TF_EXPECT_OK(result.status());
+ EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
+}
+
TEST_F(HloParserTest, AttibutesAnyOrder) {
const string original = R"(HloModule any_order_module
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
index c1f69db..6925dc3 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
@@ -162,11 +162,11 @@
// For tuples, emit the full shape because the layout of a tuple is not
// represented in a single Layout field.
layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape());
- } else {
- layout_string = StrCat(
- "{",
- absl::StrJoin(LayoutUtil::MinorToMajor(instruction->shape()), ","),
- "}");
+ } else if (instruction->shape().has_layout()) {
+ // For non-tuples, only emit the layout when the shape has a Layout.
+ // This extra check is required because LayoutUtil::HasLayout ignores
+ // token, opaque types etc.
+ layout_string = instruction->shape().layout().ToString();
}
attrs["layout"].set_s(layout_string);
}
diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
index 1e2b31a..498abcf 100644
--- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder_test.cc
@@ -17,6 +17,7 @@
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace hlo_graph_dumper {
@@ -178,6 +179,23 @@
EXPECT_GT(generator_.GetGraphDef().node_size(), 0);
}
+TEST_F(HloTfGraphBuilderTest, TokenHasNoLayout) {
+ auto builder = HloComputation::Builder("Token");
+ auto token = builder.AddInstruction(HloInstruction::CreateToken());
+ OpMetadata metadata;
+ metadata.set_op_name("x");
+ metadata.set_op_type("y");
+ token->set_metadata(metadata);
+ TF_ASSERT_OK(generator_.AddComputation(*builder.Build()));
+ GraphDef graph_def = generator_.GetGraphDef();
+ ASSERT_EQ(graph_def.node_size(), 1);
+ const auto &node = graph_def.node(0);
+ ASSERT_EQ(GetNodeAttr(node, "type").s(), "TOKEN");
+ ASSERT_EQ(GetNodeAttr(node, "layout").s(), "");
+ ASSERT_EQ(GetNodeAttr(node, "tf_op_name").s(), "x");
+ ASSERT_EQ(GetNodeAttr(node, "tf_op_type").s(), "y");
+}
+
} // namespace
} // namespace hlo_graph_dumper
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 4caaa5a..9d94c5c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -50,6 +50,7 @@
case HloOpcode::kReduceWindow:
case HloOpcode::kScatter:
case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kSort:
case HloOpcode::kFusion:
return true;
default:
@@ -376,6 +377,24 @@
get_tuple_element->tuple_index()));
}
+namespace {
+Status SameElementTypesForOperandsAndToApplyParameters(
+ const HloInstruction& instruction, int64 num_operands_to_check) {
+ const ProgramShape& to_apply = instruction.to_apply()->ComputeProgramShape();
+ for (int i = 0; i < num_operands_to_check; ++i) {
+ const Shape& parameter_shape = to_apply.parameters(i);
+ const Shape& operand_shape = instruction.operands()[i]->shape();
+ if (!ShapeUtil::SameElementType(parameter_shape, operand_shape)) {
+ return InvalidArgument(
+ "Shape mismatch between to_apply computation"
+ " parameter and operand %d in %s.",
+ i, instruction.ToString().c_str());
+ }
+ }
+ return Status::OK();
+}
+} // namespace
+
Status ShapeVerifier::HandleReduce(HloInstruction* reduce) {
if (reduce->operand_count() % 2 != 0) {
return InternalError(
@@ -387,9 +406,15 @@
for (const HloInstruction* operand : reduce->operands()) {
operand_shapes.push_back(&operand->shape());
}
- return CheckShape(reduce, ShapeInference::InferReduceShape(
- operand_shapes, reduce->dimensions(),
- reduce->to_apply()->ComputeProgramShape()));
+ TF_RETURN_IF_ERROR(
+ CheckShape(reduce, ShapeInference::InferReduceShape(
+ operand_shapes, reduce->dimensions(),
+ reduce->to_apply()->ComputeProgramShape())));
+
+ return allow_mixed_precision_
+ ? Status::OK()
+ : SameElementTypesForOperandsAndToApplyParameters(
+ *reduce, reduce->operands().size() - 1);
}
Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) {
@@ -545,19 +570,31 @@
// arbitrary map dimensions.
std::vector<int64> map_dims(max_operand_rank);
std::iota(map_dims.begin(), map_dims.end(), 0);
- return CheckShape(map, ShapeInference::InferMapShape(
- operand_shapes,
- map->to_apply()->ComputeProgramShape(), map_dims));
+
+ TF_RETURN_IF_ERROR(CheckShape(
+ map,
+ ShapeInference::InferMapShape(
+ operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)));
+
+ return allow_mixed_precision_
+ ? Status::OK()
+ : SameElementTypesForOperandsAndToApplyParameters(
+ *map, map->operands().size());
}
Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) {
TF_RETURN_IF_ERROR(CheckOperandCount(reduce_window, 2));
- return CheckShape(
+ TF_RETURN_IF_ERROR(CheckShape(
reduce_window,
ShapeInference::InferReduceWindowShape(
reduce_window->operand(0)->shape(),
reduce_window->operand(1)->shape(), reduce_window->window(),
- reduce_window->to_apply()->ComputeProgramShape()));
+ reduce_window->to_apply()->ComputeProgramShape())));
+
+ return allow_mixed_precision_
+ ? Status::OK()
+ : SameElementTypesForOperandsAndToApplyParameters(*reduce_window,
+ 1);
}
Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
index 4f69bd1..523890b 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc
@@ -552,5 +552,67 @@
HasSubstr("does not support non-array result"));
}
+static const char* const kMapOperandComputationMismatchHlo = R"(
+ HloModule MapOperandComputationMismatch
+
+ Computation {
+ param0 = f32[] parameter(0)
+ constant = f32[] constant(1)
+ ROOT add = f32[] add(param0, constant)
+ }
+
+ ENTRY kernelEntry {
+ param = f64[] parameter(0)
+ ROOT map = f32[] map(param), dimensions={}, to_apply=Computation
+})";
+
+TEST_F(HloVerifierTest, MapOperandComputationMismatch) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kMapOperandComputationMismatchHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(
+ status.error_message(),
+ HasSubstr(
+ "Shape mismatch between to_apply computation parameter and operand"));
+}
+
+TEST_F(HloVerifierTestAllowMixedPrecision, MapOperandComputationMismatch) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kMapOperandComputationMismatchHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
+static const char* const kReduceOperandComputationMismatchHlo = R"(
+ HloModule ReduceOperandComputationMismatch
+ computation {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ ROOT add = f32[] add(x, y)
+ }
+
+ ENTRY kernelEntry {
+ arg0 = f16[64,64,224,224]{3,2,1,0} parameter(0)
+ constant = f16[] constant(0)
+ reduce = f16[64]{0} reduce(arg0, constant), dimensions={0,2,3}, to_apply=computation
+ })";
+
+TEST_F(HloVerifierTest, ReduceOperandComputationMismatch) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kReduceOperandComputationMismatchHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ HasSubstr("Expected instruction to have shape equal to f32[64]"));
+}
+
+TEST_F(HloVerifierTestAllowMixedPrecision, ReduceOperandComputationMismatch) {
+ TF_ASSERT_OK_AND_ASSIGN(auto module,
+ ParseHloString(kReduceOperandComputationMismatchHlo));
+ auto status = verifier().Run(module.get()).status();
+ ASSERT_TRUE(status.ok());
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
index c66eaec..3accecc 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.cc
@@ -113,20 +113,10 @@
Shape output_shape = output_array.GetShape();
Shape update_shape = update_array.GetShape();
- IndexGenerator start_indices_generator;
- // TODO(b/118437727): Remove the R1 path, and rename the variables.
- if (start_indices_array.GetShape().rank() == 1) {
- start_indices_generator = [&](int64 index) {
- return start_indices_array.EmitReadArrayElement(
- IrArray::Index({b->getInt64(index)}), b);
- };
- } else {
- start_indices_generator = [&](int64 index) {
- return operand_arrays[2 + index].EmitReadArrayElement(
- IrArray::Index(b->getInt64Ty()), b);
- };
- }
-
+ IndexGenerator start_indices_generator = [&](int64 index) {
+ return operand_arrays[2 + index].EmitReadArrayElement(
+ IrArray::Index(b->getInt64Ty()), b);
+ };
ElementGenerator update_array_generator = [&](const IrArray::Index& index) {
return update_array.EmitReadArrayElement(index, b);
};
@@ -178,21 +168,11 @@
TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter));
ElementGenerator update_array_generator = fused_emitter.GetGenerator(update);
- // TODO(b/118437727): Remove the R1 path, and rename the variables.
- IndexGenerator start_indices_generator;
- if (start_indices->shape().rank() == 1) {
- start_indices_generator = [&](int64 index) {
- return fused_emitter.GetGenerator(start_indices)(
- IrArray::Index({b->getInt64(index)}));
- };
- } else {
- start_indices_generator = [&](int64 index) {
- ElementGenerator element_generator =
- fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
- return element_generator(IrArray::Index(b->getInt64Ty()));
- };
- }
-
+ IndexGenerator start_indices_generator = [&](int64 index) {
+ ElementGenerator element_generator =
+ fused_emitter.GetGenerator(dynamic_update_slice->operand(2 + index));
+ return element_generator(IrArray::Index(b->getInt64Ty()));
+ };
bool is_signed = ShapeUtil::ElementIsSigned(start_indices->shape());
return EmitDynamicUpdateSliceInPlaceImpl(
update_shape, start_indices_generator, is_signed, update_array_generator,
diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
index cf5083e..02c7195 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/ir_builder_mixin.h
@@ -270,6 +270,11 @@
}
template <class... Args>
+ llvm::Value* FCmpUNO(Args&&... args) {
+ return mixin_builder()->CreateFCmpUNO(std::forward<Args>(args)...);
+ }
+
+ template <class... Args>
llvm::Value* FDiv(Args&&... args) {
return mixin_builder()->CreateFDiv(std::forward<Args>(args)...);
}
diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
index fe320bb..3a35405 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_loop.cc
@@ -25,7 +25,6 @@
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index 0dc120e..a689881 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -23,7 +23,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 8343452..9bda6fb 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -368,6 +368,7 @@
const HloModuleProto* proto = module_protos[i];
const HloModuleConfig& config = *module_configs[i];
TF_ASSIGN_OR_RETURN(auto module, CreateModuleFromProto(*proto, config));
+ TF_RETURN_IF_ERROR(MaybeDumpUnoptimizedHloModule(*module));
module_group->push_back(std::move(module));
}
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 4680c92..a403428 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -534,6 +534,10 @@
p.edge_padding_high() +
std::max<int64>(operand_shape.dimensions(i) - 1, 0LL) *
p.interior_padding();
+ if (dimensions[i] < 0) {
+ return InvalidArgument("Padding result in negative size for dimension %d",
+ i);
+ }
is_dynamic[i] = operand_shape.is_dynamic_dimension(i);
}
@@ -1858,6 +1862,9 @@
fft_length[i]);
}
}
+ if (ShapeUtil::IsZeroElementArray(in)) {
+ return in;
+ }
Shape result = ShapeUtil::ChangeElementType(in, C64);
result.set_dimensions(result.dimensions_size() - 1,
fft_length[fft_rank - 1] / 2 + 1);
@@ -2433,7 +2440,7 @@
ShapeUtil::HumanString(arg));
}
- if (index >= arg.tuple_shapes_size()) {
+ if (index < 0 || index >= arg.tuple_shapes_size()) {
return InvalidArgument(
"Cannot infer shape: attempt to index out of tuple bounds: %d "
">= %d in shape %s.",
@@ -2716,13 +2723,26 @@
"Select's pred operand must have PRED element type; got %s.",
ShapeUtil::HumanString(pred));
}
- if (ShapeUtil::CompatibleIgnoringElementType(pred, on_true) ||
+ if (Shape::Equal()
+ .IgnoreElementType()
+ .IgnoreLayout()
+ .IgnoreDynamicDimension()(pred, on_true) ||
ShapeUtil::IsScalar(pred)) {
// By this stage we know that pred's element type is PRED. Therefore, this
// check restricts pred to be a PRED scalar, or a PRED array with the same
// dimensions as on_true and on_false.
- return ShapeUtil::ChangeElementType(
+ Shape inferred_shape = ShapeUtil::ChangeElementType(
on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false));
+
+ // Propagate dynamic dimensions if pred is not a scalar.
+ if (!ShapeUtil::IsScalar(pred)) {
+ for (int i = 0; i < inferred_shape.rank(); i++) {
+ if (pred.is_dynamic_dimension(i)) {
+ inferred_shape.set_dynamic_dimension(i, true);
+ }
+ }
+ }
+ return inferred_shape;
}
return InvalidArgument(
"Select operation with non-scalar predicate with dimensionality "
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index 26120a0..eabc223 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -896,6 +896,20 @@
ASSERT_TRUE(ShapeUtil::Equal(s32_, inferred1_status.ValueOrDie()));
}
+TEST_F(ShapeInferenceTest, InferTupleElementShapeOutOfBound) {
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({f32_, s32_});
+ auto inferredNegative_status =
+ ShapeInference::InferGetTupleElementShape(tuple_shape, -1);
+ auto inferred2_status =
+ ShapeInference::InferGetTupleElementShape(tuple_shape, 2);
+ ASSERT_FALSE(inferredNegative_status.ok());
+ ASSERT_FALSE(inferred2_status.ok());
+ EXPECT_THAT(inferredNegative_status.status().error_message(),
+ HasSubstr("attempt to index out of tuple bounds"));
+ EXPECT_THAT(inferred2_status.status().error_message(),
+ HasSubstr("attempt to index out of tuple bounds"));
+}
+
TEST_F(ShapeInferenceTest, InferPowShape) {
auto ten_floats = ShapeUtil::MakeShape(F32, {10});
auto inferred_status = ShapeInference::InferBinaryOpShape(
@@ -1467,6 +1481,14 @@
Shape inferred_shape = inferred_status.ValueOrDie();
ASSERT_TRUE(
ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {39, 31}), inferred_shape));
+
+ dimension1->set_edge_padding_low(-20);
+ dimension1->set_edge_padding_high(-10);
+ auto negative_dimension_size = ShapeInference::InferPadShape(
+ input_shape, padding_value_shape, padding_config);
+ ASSERT_FALSE(negative_dimension_size.ok());
+ ASSERT_THAT(negative_dimension_size.status().error_message(),
+ HasSubstr("negative size for dimension 1"));
}
TEST_F(ShapeInferenceTest, Reverse) {
diff --git a/tensorflow/compiler/xla/service/sort_simplifier_test.cc b/tensorflow/compiler/xla/service/sort_simplifier_test.cc
index cd05fcf..a05bc79 100644
--- a/tensorflow/compiler/xla/service/sort_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/sort_simplifier_test.cc
@@ -34,13 +34,21 @@
const char* hlo_string = R"(
HloModule permutation_sort
- ENTRY sort_computation {
- keys = f32[64,8732]{1,0} parameter(0)
- values = s32[64,8732]{1,0} parameter(1)
- sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
- dimensions={1}
- ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
- })";
+ compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+ }
+
+ ENTRY sort_computation {
+ keys = f32[64,8732]{1,0} parameter(0)
+ values = s32[64,8732]{1,0} parameter(1)
+ sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values),
+ dimensions={1}, to_apply=compare
+ ROOT gte = f32[64,8732]{1,0} get-tuple-element(sort), index=0
+ })";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
@@ -58,17 +66,27 @@
const char* hlo_string = R"(
HloModule permutation_sort
- ENTRY sort_computation {
- keys = f32[64,87] parameter(0)
- values.0 = s32[64,87] parameter(1)
- values.1 = u32[64,87] parameter(2)
- sort = (f32[64,87], s32[64,87], u32[64,87]) sort(
- keys, values.0, values.1),
- dimensions={1}
- gte.0 = f32[64,87] get-tuple-element(sort), index=0
- gte.1 = u32[64,87] get-tuple-element(sort), index=2
- ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1)
- })";
+ compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ p.2.lhs = u32[] parameter(4)
+ p.2.rhs = u32[] parameter(5)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+ }
+
+ ENTRY sort_computation {
+ keys = f32[64,87] parameter(0)
+ values.0 = s32[64,87] parameter(1)
+ values.1 = u32[64,87] parameter(2)
+ sort = (f32[64,87], s32[64,87], u32[64,87]) sort(
+ keys, values.0, values.1),
+ dimensions={1}, to_apply=compare
+ gte.0 = f32[64,87] get-tuple-element(sort), index=0
+ gte.1 = u32[64,87] get-tuple-element(sort), index=2
+ ROOT tuple = (f32[64,87], u32[64,87]) tuple(gte.0, gte.1)
+ })";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
@@ -86,12 +104,20 @@
const char* hlo_string = R"(
HloModule permutation_sort
- ENTRY sort_computation {
- keys = f32[64,8732]{1,0} parameter(0)
- values = s32[64,8732]{1,0} parameter(1)
- sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}
- ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
- })";
+ compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+ }
+
+ ENTRY sort_computation {
+ keys = f32[64,8732]{1,0} parameter(0)
+ values = s32[64,8732]{1,0} parameter(1)
+ sort = (f32[64,8732]{1,0}, s32[64,8732]{1,0}) sort(keys, values), dimensions={1}, to_apply=compare
+ ROOT gte = s32[64,8732]{1,0} get-tuple-element(sort), index=1
+ })";
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_string));
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
index fd5759e..5516026 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc
@@ -19,6 +19,7 @@
#include <memory>
#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
@@ -1065,14 +1066,16 @@
TEST_F(CanShareOperandBufferWithUserTest, SortCanShare) {
auto builder = HloComputation::Builder(TestName());
+ module_ = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
auto keys = builder.AddInstruction(
HloInstruction::CreateParameter(0, keys_shape, "keys"));
- auto sort =
- builder.AddInstruction(HloInstruction::CreateSort(keys_shape, 0, keys));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort, MakeSortHlo(keys_shape, {keys}, 0, &builder, module_.get()));
- BuildModuleAndRunAnalysis(builder.Build());
+ computation_ = module_->AddEntryComputation(builder.Build());
+ RunAnalysis();
EXPECT_TRUE(
points_to_analysis_->CanShareOperandBufferWithUser(keys, {}, sort, {}));
@@ -1080,6 +1083,7 @@
TEST_F(CanShareOperandBufferWithUserTest, SortCanShareWithTupleUser) {
auto builder = HloComputation::Builder(TestName());
+ module_ = CreateNewVerifiedModule();
Shape keys_shape = ShapeUtil::MakeShape(F32, {8});
Shape values_shape = ShapeUtil::MakeShape(F32, {8});
@@ -1087,11 +1091,13 @@
HloInstruction::CreateParameter(0, keys_shape, "keys"));
auto values = builder.AddInstruction(
HloInstruction::CreateParameter(1, values_shape, "values"));
- auto sort = builder.AddInstruction(HloInstruction::CreateSort(
- ShapeUtil::MakeTupleShape({keys_shape, values_shape}), 0, keys,
- {values}));
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto* sort,
+ MakeSortHlo(ShapeUtil::MakeTupleShape({keys_shape, values_shape}),
+ {keys, values}, 0, &builder, module_.get()));
- BuildModuleAndRunAnalysis(builder.Build());
+ computation_ = module_->AddEntryComputation(builder.Build());
+ RunAnalysis();
// The buffer for the keys can be shared with the first tuple entry.
EXPECT_TRUE(
diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h
index e6b4e87..1d59490 100644
--- a/tensorflow/compiler/xla/shape.h
+++ b/tensorflow/compiler/xla/shape.h
@@ -181,6 +181,10 @@
bool ignore_dynamic_dimension_ = false;
};
+ // Test that all fields of the shape are the same, equivalent to Equal().
+ bool operator==(const Shape& other) const { return Equal()(*this, other); }
+ bool operator!=(const Shape& other) const { return !(*this == other); }
+
private:
// The element type of this shape (tuple, array, etc).
PrimitiveType element_type_ = PRIMITIVE_TYPE_INVALID;
diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc
index 55ce5fe..526abaf 100644
--- a/tensorflow/compiler/xla/shape_test.cc
+++ b/tensorflow/compiler/xla/shape_test.cc
@@ -85,6 +85,24 @@
EXPECT_EQ("f32[<=23,44,55]", array_shape.ToString());
}
+TEST_F(ShapeTest, EqualityTest) {
+ // Different layouts.
+ EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {0, 1}));
+
+ // Different dims.
+ EXPECT_NE(ShapeUtil::MakeShapeWithLayout(F32, {44, 23}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));
+
+ // Different elements.
+ EXPECT_NE(ShapeUtil::MakeShapeWithLayout(S32, {44, 23}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));
+
+ // Equal shapes.
+ EXPECT_EQ(ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {23, 44}, {1, 0}));
+}
+
TEST_F(ShapeTest, IsStatic) {
EXPECT_TRUE(opaque_.is_static());
EXPECT_TRUE(token_.is_static());
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 8fb6742..e8e779f 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -670,22 +670,19 @@
xla_test(
name = "exhaustive_f32_elementwise_op_test",
- size = "enormous",
srcs = ["exhaustive_f32_elementwise_op_test.cc"],
- backends = [
- "cpu",
- "gpu",
- ],
+ real_hardware_only = True, # Very slow on the interpreter.
shard_count = 48,
tags = [
- "broken",
- "manual",
- "notap",
+ "optonly",
+ # This is a big test that we skip for capacity reasons in OSS testing.
+ "nooss",
],
deps = [
":client_library_test_base",
":literal_test_util",
"//tensorflow/compiler/xla/client:xla_builder",
+ "//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"@com_google_absl//absl/base",
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 7379fbc..acdd3c9 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -35,7 +35,6 @@
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index 05d4d04..c14d279 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -34,6 +34,7 @@
xla_test_library_deps = [],
backends = [],
blacklisted_backends = [],
+ real_hardware_only = False,
args = [],
tags = [],
copts = [],
@@ -108,6 +109,10 @@
use for that target.
**kwargs: Additional keyword arguments to pass to native.cc_test.
"""
+
+ # All of the backends in all_backends are real hardware.
+ _ignore = [real_hardware_only]
+
test_names = []
if not backends:
backends = all_backends
diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
index 87e912f..b961e61 100644
--- a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
+++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc
@@ -13,7 +13,9 @@
limitations under the License.
==============================================================================*/
+#include <cmath>
#include "absl/base/casts.h"
+#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -21,12 +23,22 @@
namespace xla {
namespace {
+
class ExhaustiveF32ElementwiseOpTest
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<std::pair<int64, int64>> {
protected:
ErrorSpec error_spec_{0.0001, 0.0001};
+ bool IsClose(float expected, float actual) {
+ float abs_err = std::abs(expected - actual);
+ float rel_err = abs_err / std::abs(expected);
+ return abs_err < error_spec_.abs || rel_err < error_spec_.rel ||
+ (std::isnan(expected) && std::isnan(actual)) ||
+ (std::isinf(expected) && std::isinf(actual) &&
+ (expected > 0) == (actual > 0));
+ }
+
template <typename EnqueueOpTy>
void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op,
float (*evaluate_op)(float),
@@ -104,33 +116,60 @@
// b) we can print out better error messages (namely, we can print out
// which floating-point value input failed, while LiteralTestUtil::Near
// can only print out the input index that failed).
+ // c) we need special handling of certain inputs. For example, we say that
+ // a denormal input has multiple correct outputs (namely, f(x) and f(0))
+ // and just needs to be close to one of them.
absl::Span<float> result_arr = result_literal.data<float>();
ASSERT_EQ(result_arr.size(), input_arr.size());
int64 mismatches = 0;
+ // Hoisting this out of the loop is a nice speedup on shards that have many
+ // denormals.
+ const float expected_at_zero = evaluate_op(0);
for (int64 i = 0; i < input_arr.size(); ++i) {
float input = ith_input_elem(i);
- float expected = evaluate_op(input);
float actual = result_arr[i];
- float abs_err = std::abs(expected - actual);
- float rel_err = abs_err / std::abs(expected);
- if (abs_err < error_spec_.abs || rel_err < error_spec_.rel ||
- (std::isnan(expected) && std::isnan(actual)) ||
- (std::isinf(expected) && std::isinf(actual) &&
- (expected > 0) == (actual > 0))) {
- // Successful match! Nothing to do.
+ float expected = evaluate_op(input);
+ if (IsClose(expected, actual)) {
+ continue;
+ }
+
+ constexpr int64 kMaxMismatchesPrinted = 1000;
+ if (std::fpclassify(input) == FP_SUBNORMAL) {
+ // For denormal inputs, we accept answers that are close to either
+ // - evaluate_op(input) OR
+ // - evaluate_op(0).
+ if (IsClose(expected_at_zero, actual)) {
+ continue;
+ }
+ ++mismatches;
+ if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) {
+ // Use %0.9g because that's guaranteed to print an f32 to full
+ // precision.
+ LOG(ERROR) << absl::StreamFormat(
+ "Mismatch on denormal value %0.9g (0x%08x). Expected either "
+ "%0.9g (0x%08x) (evaluated at true value) or %0.9g (0x%08x) "
+ "(evaluated at zero), but got %0.9g (0x%08x).",
+ input, absl::bit_cast<uint32>(input), //
+ expected, absl::bit_cast<uint32>(expected), //
+ expected_at_zero, absl::bit_cast<uint32>(expected_at_zero),
+ actual, absl::bit_cast<uint32>(actual));
+ }
} else {
- constexpr int64 kMaxMismatchesPrinted = 1000;
mismatches++;
if (mismatches < kMaxMismatchesPrinted || VLOG_IS_ON(2)) {
- LOG(ERROR) << "Mismatch on " << input << " (0x"
- << absl::StrCat(absl::Hex(input, absl::kZeroPad8))
- << "). Expected " << expected << ", but got " << actual;
+ LOG(ERROR) << absl::StreamFormat(
+ "Mismatch on %0.9g (0x%08x). Expected %0.9g (0x%08x), but got "
+ "%0.9g (0x%08x).",
+ input, absl::bit_cast<uint32>(input), //
+ expected, absl::bit_cast<uint32>(expected), //
+ actual, absl::bit_cast<uint32>(actual));
}
- if (mismatches == kMaxMismatchesPrinted && !VLOG_IS_ON(2)) {
- LOG(ERROR) << "Not printing any more mismatches; pass "
- "--vmodule=exhaustive_f32_elementwise_op_test=2 to see "
- "all of them.";
- }
+ }
+
+ if (mismatches == kMaxMismatchesPrinted && !VLOG_IS_ON(2)) {
+ LOG(ERROR) << "Not printing any more mismatches; pass "
+ "--vmodule=exhaustive_f32_elementwise_op_test=2 to see "
+ "all of them.";
}
}
EXPECT_EQ(mismatches, 0);
@@ -138,18 +177,12 @@
};
XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) {
-#ifdef XLA_TEST_BACKEND_CPU
- // TODO(b/73141998): The vectorized Log implementation gives results outside
- // our error spec in this range (these numbers are bitwise representations of
- // floats expressed as a zero extended int64).
- std::pair<int64, int64> known_incorrect_range = {1, 8388608};
-#else
- std::pair<int64, int64> known_incorrect_range = {0, 0};
+#if !defined(XLA_TEST_BACKEND_CPU) && !defined(XLA_TEST_BACKEND_GPU)
+ error_spec_ = ErrorSpec{0.001, 0.001};
#endif
-
ExhaustivelyTestF32Op(
[](XlaBuilder* builder, const XlaOp& input) { Log(input); }, std::log,
- known_incorrect_range);
+ /*known_incorrect_range=*/{0, 0});
}
XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) {
@@ -174,6 +207,18 @@
/*known_incorrect_range=*/{0, 0});
}
+XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfF32) {
+ ExhaustivelyTestF32Op(
+ [](XlaBuilder* builder, const XlaOp& input) { Erf(input); }, std::erf,
+ /*known_incorrect_range=*/{0, 0});
+}
+
+XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ErfcF32) {
+ ExhaustivelyTestF32Op(
+ [](XlaBuilder* builder, const XlaOp& input) { Erfc(input); }, std::erfc,
+ /*known_incorrect_range=*/{0, 0});
+}
+
std::vector<std::pair<int64, int64>> CreateExhaustiveParameters() {
// We break up the 2^32-element space into small'ish chunks to keep peak
// memory usage low.
diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
index f80d29b..e2cf4c0 100644
--- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc
@@ -34,7 +34,6 @@
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 591d6c1..321c3fb 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -136,10 +136,18 @@
auto module = ParseHloString(R"(
HloModule sort.148.1589
+compare {
+ p.0.lhs = f32[] parameter(0)
+ p.0.rhs = f32[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (f32[1048576], s32[1048576]) {
%parameter.0 = f32[1048576]{0} parameter(0)
%parameter.1 = s32[1048576]{0} parameter(1)
- ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+ ROOT %sort.148.1589 = (f32[1048576]{0}, s32[1048576]{0}) sort(f32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare
}
)")
.ValueOrDie();
@@ -159,10 +167,18 @@
auto module = ParseHloString(R"(
HloModule sort.148.1589
+compare {
+ p.0.lhs = s32[] parameter(0)
+ p.0.rhs = s32[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (s32[1048576], s32[1048576]) {
%parameter.0 = s32[1048576]{0} parameter(0)
%parameter.1 = s32[1048576]{0} parameter(1)
- ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}
+ ROOT %sort.148.1589 = (s32[1048576]{0}, s32[1048576]{0}) sort(s32[1048576]{0} %parameter.0, s32[1048576]{0} %parameter.1), dimensions={0}, to_apply=compare
}
)")
.ValueOrDie();
@@ -182,10 +198,18 @@
auto module = ParseHloString(R"(
HloModule sort, is_scheduled=true
+compare {
+ p.0.lhs = bf16[] parameter(0)
+ p.0.rhs = bf16[] parameter(1)
+ p.1.lhs = s32[] parameter(2)
+ p.1.rhs = s32[] parameter(3)
+ ROOT lt = pred[] less-than(p.0.lhs, p.0.rhs)
+}
+
ENTRY %sort. (parameter.0: bf16[2,1452], parameter.1: s32[2,1452]) -> (bf16[2,1452], s32[2,1452]) {
%parameter.0 = bf16[2,1452]{1,0} parameter(0)
%parameter.1 = s32[2,1452]{1,0} parameter(1)
- ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}
+ ROOT %sort = (bf16[2,1452]{1,0}, s32[2,1452]{1,0}) sort(bf16[2,1452]{1,0} %parameter.0, s32[2,1452]{1,0} %parameter.1), dimensions={1}, to_apply=compare
}
)")
.ValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc
index 4fbd7f2..c51f30f 100644
--- a/tensorflow/compiler/xla/tests/unary_op_test.cc
+++ b/tensorflow/compiler/xla/tests/unary_op_test.cc
@@ -64,7 +64,9 @@
&builder, {-2, 25, 0, static_cast<T>(-0.0), -123, inf<T>(), -inf<T>()});
Sign(arg);
- ComputeAndCompareR1<T>(&builder, {-1, 1, 0, 0, -1, 1, -1}, {});
+ ComputeAndCompareR1<T>(
+ &builder,
+ {-1, 1, static_cast<T>(+0.0), static_cast<T>(-0.0), -1, 1, -1}, {});
}
template <typename T>
diff --git a/tensorflow/compiler/xla/text_literal_writer.cc b/tensorflow/compiler/xla/text_literal_writer.cc
index 7289ae7..fc7949d 100644
--- a/tensorflow/compiler/xla/text_literal_writer.cc
+++ b/tensorflow/compiler/xla/text_literal_writer.cc
@@ -24,7 +24,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
index 4375e7c..df2d3d1 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc
@@ -31,7 +31,6 @@
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
index 7235698..35bb82c 100644
--- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
+++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc
@@ -26,7 +26,6 @@
#include "tensorflow/compiler/xla/service/service.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc
index c01a47b..21217c2 100644
--- a/tensorflow/compiler/xla/tools/replay_computation.cc
+++ b/tensorflow/compiler/xla/tools/replay_computation.cc
@@ -90,8 +90,8 @@
int num_runs = 1;
};
-std::unique_ptr<LocalExecutable> CompileExecutable(const HloSnapshot& module,
- LocalClient* client) {
+StatusOr<std::unique_ptr<LocalExecutable>> CompileExecutable(
+ const HloSnapshot& module, LocalClient* client) {
XlaComputation computation(module.hlo().hlo_module());
std::vector<Shape> argument_layouts;
argument_layouts.reserve(
@@ -102,9 +102,8 @@
argument_layouts.push_back(Shape(param));
argument_layout_ptrs.push_back(&argument_layouts.back());
}
- return client
- ->Compile(computation, argument_layout_ptrs, ExecutableBuildOptions())
- .ValueOrDie();
+ return client->Compile(computation, argument_layout_ptrs,
+ ExecutableBuildOptions());
}
absl::optional<Shape> GetXfeedShape(bool is_infeed,
@@ -357,7 +356,7 @@
// Compile all the modules in parallel.
LOG(INFO) << "Compiling " << snapshots.size() << " modules in parallel.";
- std::vector<std::unique_ptr<LocalExecutable>> executables;
+ std::vector<StatusOr<std::unique_ptr<LocalExecutable>>> executables;
{
// ThreadPool CHECK-fails if we give it 0 threads.
tensorflow::thread::ThreadPool thread_pool(
@@ -374,7 +373,12 @@
LOG(INFO) << "Done compiling; now running the modules.";
for (int64 i = 0; i < executables.size(); ++i) {
- LocalExecutable* executable = executables[i].get();
+ if (!executables[i].ok()) {
+ LOG(ERROR) << "Compilation failed: " << executables[i].status();
+ exit_status = EXIT_FAILURE;
+ continue;
+ }
+ LocalExecutable* executable = executables[i].ValueOrDie().get();
LOG(ERROR) << "Running iteration " << i;
StatusOr<Literal> result_status =
ReplayComputation(snapshots[i], executable, client, opts);
diff --git a/tensorflow/compiler/xla/tools/show_signature.cc b/tensorflow/compiler/xla/tools/show_signature.cc
index cdf306d..b80d0db 100644
--- a/tensorflow/compiler/xla/tools/show_signature.cc
+++ b/tensorflow/compiler/xla/tools/show_signature.cc
@@ -37,7 +37,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index 2ee1a6c..b791519 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -68,9 +68,11 @@
Status CompilationCacheKey(const xrt::XLAComputation& computation,
string* key) {
- string serialized;
- TF_RET_CHECK(SerializeToStringDeterministic(computation, &serialized));
- uint64 fingerprint = Fingerprint64(serialized);
+ const size_t size = computation.ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ TF_RET_CHECK(
+ SerializeToBufferDeterministic(computation, serialized.get(), size));
+ uint64 fingerprint = Fingerprint64(absl::string_view(serialized.get(), size));
*key = absl::StrCat(fingerprint);
return Status::OK();
}
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
index 116c193..42ef881 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -23,7 +23,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/compiler/xrt/xrt.pb.h"
#include "tensorflow/compiler/xrt/xrt_compilation_cache.h"
#include "tensorflow/compiler/xrt/xrt_device.h"
diff --git a/tensorflow/compiler/xrt/xrt_state.cc b/tensorflow/compiler/xrt/xrt_state.cc
index 1e2a958..78a1b6a 100644
--- a/tensorflow/compiler/xrt/xrt_state.cc
+++ b/tensorflow/compiler/xrt/xrt_state.cc
@@ -31,7 +31,6 @@
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/random/random.h"
diff --git a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
index 416b719..39c2a2e 100644
--- a/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
+++ b/tensorflow/contrib/bigtable/ops/bigtable_ops.cc
@@ -59,7 +59,7 @@
.Input("table: resource")
.Input("prefix: string")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
@@ -68,14 +68,14 @@
.Input("start_key: string")
.Input("end_key: string")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("BigtableSampleKeysDataset")
.Input("table: resource")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
@@ -85,7 +85,7 @@
.Input("start_key: string")
.Input("end_key: string")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
@@ -100,7 +100,7 @@
.Input("columns: string")
.Input("probability: float")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
index d3b23d9..64e4c45 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD
@@ -193,8 +193,9 @@
py_test(
name = "estimator_test",
- size = "large",
+ size = "medium",
srcs = ["estimator_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_gpu",
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index 21ae9a0..8b63953 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -13,6 +13,7 @@
tensorflow/core/lib/core
tensorflow/core/profiler
tensorflow/core/protobuf
+tensorflow/core/protobuf/tpu
tensorflow/core/util
tensorflow/examples
tensorflow/examples/tutorials
@@ -437,7 +438,6 @@
tensorflow/contrib/tpu
tensorflow/contrib/tpu/ops
tensorflow/contrib/tpu/profiler
-tensorflow/contrib/tpu/proto
tensorflow/contrib/tpu/python
tensorflow/contrib/tpu/python/ops
tensorflow/contrib/tpu/python/profiler
diff --git a/tensorflow/contrib/cmake/python_protos.txt b/tensorflow/contrib/cmake/python_protos.txt
index 013180c..b460320 100644
--- a/tensorflow/contrib/cmake/python_protos.txt
+++ b/tensorflow/contrib/cmake/python_protos.txt
@@ -1,6 +1,7 @@
tensorflow/core
tensorflow/core/kernels/boosted_trees
tensorflow/core/profiler
+tensorflow/core/protobuf/tpu
tensorflow/python
tensorflow/contrib/boosted_trees/proto
tensorflow/contrib/cloud/kernels
@@ -12,7 +13,6 @@
tensorflow/contrib/session_bundle
tensorflow/contrib/tensor_forest/proto
tensorflow/contrib/tensorboard/plugins/projector
-tensorflow/contrib/tpu/proto
tensorflow/contrib/tpu/profiler
tensorflow/contrib/training/python/training
tensorflow/contrib/verbs
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index d8d1cc3..cc263d7 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -125,9 +125,9 @@
file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir}
"${tensorflow_source_dir}/tensorflow/core/*.proto"
+ "${tensorflow_source_dir}/tensorflow/core/protobuf/tpu/*.proto"
"${tensorflow_source_dir}/tensorflow/compiler/xla/*.proto"
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto"
- "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto"
)
RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS
diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
index 6c5f8c6..4db711c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
@@ -25,11 +25,13 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class AssertElementShapeTest(test_base.DatasetTestBase):
def test_assert_element_shape(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index b9840b1..220f993 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -27,12 +27,14 @@
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
index e7281d5..78019fc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
@@ -25,10 +25,12 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 2527706..9275a36 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -26,11 +26,13 @@
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index c67af0e..f4c9e00 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -131,24 +131,25 @@
],
)
+cuda_py_test(
+ name = "one_device_strategy_test",
+ srcs = ["one_device_strategy_test.py"],
+ additional_deps = [
+ ":strategy_test_lib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/distribute:one_device_strategy",
+ "//tensorflow/python/eager:test",
+ ],
+)
+
py_library(
name = "collective_all_reduce_strategy",
srcs = ["collective_all_reduce_strategy.py"],
visibility = ["//tensorflow:internal"],
deps = [
- ":mirrored_strategy",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:collective_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/distribute:cross_device_ops",
- "//tensorflow/python/distribute:cross_device_utils",
- "//tensorflow/python/distribute:input_lib",
- "//tensorflow/python/distribute:multi_worker_util",
- "//tensorflow/python/distribute:numpy_dataset",
- "//tensorflow/python/distribute:values",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/distribute:collective_all_reduce_strategy",
+ "//tensorflow/python/distribute:distribute_lib",
+ "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
],
)
@@ -202,18 +203,6 @@
],
)
-py_test(
- name = "one_device_strategy_test",
- srcs = ["one_device_strategy_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":strategy_test_lib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python/distribute:one_device_strategy",
- "//tensorflow/python/eager:test",
- ],
-)
-
# TODO(priyag): Rename this test to mirrored_strategy_test
cuda_py_test(
name = "mirrored_strategy_multigpu_test",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index aa4d82b..1974162 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -18,30 +18,18 @@
from __future__ import division
from __future__ import print_function
-import copy
-
-from tensorflow.contrib.distribute.python import mirrored_strategy
-from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
-from tensorflow.python.distribute import cross_device_utils
-from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import collective_all_reduce_strategy
from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import input_lib
-from tensorflow.python.distribute import multi_worker_util
-from tensorflow.python.distribute import numpy_dataset
-from tensorflow.python.distribute import values
-from tensorflow.python.eager import context
-from tensorflow.python.eager import tape
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import collective_ops
-from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
+from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
# TODO(yuefengz): support in-graph replication.
class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy):
"""Distribution strategy that uses collective ops for all-reduce.
+ *** contrib version ***
+
It is similar to the MirroredStrategy but it uses collective ops for
reduction.
@@ -64,311 +52,19 @@
CollectiveAllReduceExtended(self, num_gpus_per_worker))
-class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
+class CollectiveAllReduceExtended(
+ collective_all_reduce_strategy.CollectiveAllReduceExtended):
"""Implementation of CollectiveAllReduceStrategy."""
def __init__(self, container_strategy, num_gpus_per_worker):
- distribute_lib.DistributionStrategyExtended.__init__(
- self, container_strategy)
- self._cross_device_ops = None
- self._num_gpus_per_worker = num_gpus_per_worker
- self._initialize_local_worker(num_gpus_per_worker)
- assert isinstance(self._get_cross_device_ops(),
- cross_device_ops_lib.CollectiveAllReduce)
-
- def _initialize_local_worker(self, num_gpus_per_worker):
- """Initializes the object for local training."""
- self._is_chief = True
- self._num_workers = 1
-
- if num_gpus_per_worker:
- local_devices = tuple(
- "/device:GPU:%d" % i for i in range(num_gpus_per_worker)
- )
- else:
- local_devices = ("/device:CPU:0",)
- self._worker_device = device_util.canonicalize("/device:CPU:0")
- self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
-
- self._collective_keys = cross_device_utils.CollectiveKeys()
- self._initialize_local(local_devices)
- # TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce.
- self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
- num_workers=self._num_workers,
- num_gpus_per_worker=num_gpus_per_worker,
- collective_keys=self._collective_keys)
-
- self._cluster_spec = None
- self._task_type = None
- self._task_id = None
-
- logging.info("CollectiveAllReduceStrategy with local_devices = %r",
- local_devices)
-
- def _initialize_multi_worker(self, num_gpus_per_worker, cluster_spec,
- task_type, task_id):
- """Initializes the object for multi-worker training."""
- if task_type is None or task_id is None:
- raise ValueError("When `cluster_spec` is given, you must also specify "
- "`task_type` and `task_id`")
- if task_type not in ("chief", "worker"):
- raise ValueError(
- "Unrecognized task_type: %r, valid task types are: \"chief\", "
- "\"worker\"." % task_type)
- cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
- if not self._num_workers:
- raise ValueError("No `worker` or `chief` tasks can be found in "
- "`cluster_spec`.")
-
- self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
- task_id)
-
- self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
- self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
- if num_gpus_per_worker:
- local_devices = tuple(
- "%s/device:GPU:%d" % (self._worker_device, i)
- for i in range(num_gpus_per_worker)
- )
- else:
- local_devices = (self._worker_device,)
-
- self._collective_keys = cross_device_utils.CollectiveKeys()
- self._initialize_local(local_devices)
- self._input_workers = input_lib.InputWorkers(
- self._device_map, [(self._worker_device, self.worker_devices)])
- self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
- num_workers=self._num_workers,
- num_gpus_per_worker=num_gpus_per_worker,
- collective_keys=self._collective_keys)
-
- # Add a default device so that ops without specified devices will not end up
- # on other workers.
- self._default_device = "/job:%s/task:%d" % (task_type, task_id)
-
- self._cluster_spec = multi_worker_util.normalize_cluster_spec(cluster_spec)
- self._task_type = task_type
- self._task_id = task_id
-
- logging.info(
- "Multi-worker CollectiveAllReduceStrategy with "
- "cluster_spec = %r, task_type = %r, task_id = %r, "
- "num_workers = %r, local_devices = %r", cluster_spec.as_dict(),
- task_type, task_id, self._num_workers, local_devices)
-
- def _create_variable(self, next_creator, *args, **kwargs):
- colocate_with = kwargs.pop("colocate_with", None)
- if colocate_with is None:
- device_map = self._device_map
- logical_device = 0 # TODO(josh11b): Get logical device from scope here.
- elif isinstance(colocate_with, numpy_dataset.SingleDevice):
- with ops.device(colocate_with.device):
- return next_creator(*args, **kwargs)
- else:
- device_map = colocate_with.device_map
- logical_device = colocate_with.logical_device
-
- def _real_mirrored_creator(devices, *args, **kwargs):
- """Creates one MirroredVariable on the current worker."""
- unique_var_name = ops.get_default_graph().unique_name(
- kwargs["name"], mark_as_used=False).rstrip("/")
- # pylint: disable=protected-access
- collective_instance_key = self._collective_keys.get_instance_key(
- key_id=unique_var_name)
- # Only the first device participles in the broadcast of initial values.
- group_key = self._collective_keys.get_group_key([devices[0]])
- group_size = self._num_workers
- if "initial_value" not in kwargs:
- raise ValueError("Initial value must be specified.")
- initial_value = kwargs["initial_value"]
- if callable(initial_value):
- initial_value_fn = initial_value
- else:
- initial_value_fn = lambda: initial_value
-
- value_list = []
- for i, d in enumerate(devices):
- with ops.init_scope(), ops.device(d):
- if i == 0:
- # The initial value fn makes sure variables all initialized to
- # same values. The first device of the chief worker will send their
- # variable values to other workers.
- def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
- with ops.device(device):
- initial_value = initial_value_fn()
- assert not callable(initial_value)
- initial_value = ops.convert_to_tensor(initial_value)
-
- assert index == 0, index
- if self._num_workers > 1:
- if self._is_chief:
- bcast_send = collective_ops.broadcast_send(
- initial_value, initial_value.shape, initial_value.dtype,
- group_size, group_key, collective_instance_key)
- with ops.control_dependencies([bcast_send]):
- return array_ops.identity(initial_value)
- else:
- return collective_ops.broadcast_recv(
- initial_value.shape, initial_value.dtype, group_size,
- group_key, collective_instance_key)
- return initial_value
- else:
- # Give replicas meaningful distinct names:
- var0name = value_list[0].name.split(":")[0]
- # We append a / to variable names created on replicas with id > 0 to
- # ensure that we ignore the name scope and instead use the given
- # name as the absolute name of the variable.
- kwargs["name"] = "%s/replica_%d/" % (var0name, i)
-
- # Variables on non-first replica get initial values from the
- # variables created on the first device of each worker.
- def _overridden_initial_value_fn(device=d, index=i):
- assert index > 0
- with ops.device(device):
- if context.executing_eagerly():
- return array_ops.identity(value_list[0].value())
- else:
- return array_ops.identity(value_list[0].initial_value)
-
- kwargs["initial_value"] = _overridden_initial_value_fn
- with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
- # Don't record operations (e.g. other variable reads) during
- # variable creation.
- with tape.stop_recording():
- v = next_creator(*args, **kwargs)
-
- if i == 0:
- actual_var_name = v.name.split(":")[0]
- assert unique_var_name == actual_var_name, "%r vs %r" % (
- unique_var_name, actual_var_name)
- assert not isinstance(v, values.DistributedVariable)
- value_list.append(v)
- return value_list
-
- # pylint: disable=protected-access
- return mirrored_strategy._create_mirrored_variable(
- self._container_strategy(), device_map, logical_device,
- _real_mirrored_creator, *args, **kwargs)
-
- def _make_dataset_iterator(self, dataset):
- return input_lib.DatasetIterator(dataset, self._input_workers,
- self._num_replicas_in_sync)
-
- def _make_input_fn_iterator(
- self,
- input_fn,
- replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
- """Distributes the dataset to each local GPU."""
- if self._cluster_spec is None:
- input_pipeline_id = 0
- else:
- input_pipeline_id = multi_worker_util.id_in_cluster(
- self._cluster_spec, self._task_type, self._task_id)
- input_context = distribute_lib.InputContext(
- num_input_pipelines=self._num_workers,
- input_pipeline_id=input_pipeline_id,
- num_replicas_in_sync=self._num_replicas_in_sync)
-
- return input_lib.InputFunctionIterator(
- input_fn, self._input_workers, [input_context])
-
- def _configure(self,
- session_config=None,
- cluster_spec=None,
- task_type=None,
- task_id=None):
- """Configures the object.
-
- Args:
- session_config: a `tf.ConfigProto`
- cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
- cluster configurations.
- task_type: the current task type, such as "worker".
- task_id: the current task id.
-
- Raises:
- ValueError: if `task_type` is not in the `cluster_spec`.
- """
- if not self._cluster_spec and cluster_spec:
- # If a `cluster_spec` is already passed in, do nothing here.
- # TODO(yuefengz): check `cluster_spec` is the same if this object has
- # already been initialized with a `cluster_spec`.
- self._initialize_multi_worker(self._num_gpus_per_worker, cluster_spec,
- task_type, task_id)
- assert isinstance(self._get_cross_device_ops(),
- cross_device_ops_lib.CollectiveAllReduce)
-
- if session_config:
- session_config.CopyFrom(self._update_config_proto(session_config))
-
- def _update_config_proto(self, config_proto):
- updated_config = copy.deepcopy(config_proto)
- # Enable the scoped allocator optimization for CollectiveOps. This
- # optimization converts many small all-reduces into fewer larger
- # all-reduces.
- rewrite_options = updated_config.graph_options.rewrite_options
- rewrite_options.scoped_allocator_optimization = (
- rewriter_config_pb2.RewriterConfig.ON)
- # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
- # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
- # clear and then append.
- del rewrite_options.scoped_allocator_opts.enable_op[:]
- rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
-
- if not self._cluster_spec:
- return updated_config
-
- assert self._task_type
- assert self._task_id is not None
-
- # Collective group leader is needed for collective ops to coordinate
- # workers.
- if "chief" in self._cluster_spec.jobs:
- updated_config.experimental.collective_group_leader = (
- "/job:chief/replica:0/task:0")
- else:
- if "worker" not in self._cluster_spec.jobs:
- raise ValueError(
- "You must have `chief` or `worker` jobs in the `cluster_spec`.")
- updated_config.experimental.collective_group_leader = (
- "/job:worker/replica:0/task:0")
-
- # The device filters prevent communication between workers.
- del updated_config.device_filters[:]
- updated_config.device_filters.append(
- "/job:%s/task:%d" % (self._task_type, self._task_id))
-
- return updated_config
-
- @property
- def experimental_between_graph(self):
- return True
-
- @property
- def experimental_should_init(self):
- return True
-
- @property
- def should_checkpoint(self):
- return self._is_chief
-
- @property
- def should_save_summary(self):
- return self._is_chief
-
- @property
- def _num_replicas_in_sync(self):
- return len(self.worker_devices) * self._num_workers
-
- # TODO(priyag): Delete this once all strategies use global batch size.
- @property
- def _global_batch_size(self):
- """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
-
- `make_input_fn_iterator` assumes per-replica batching.
-
- Returns:
- Boolean.
- """
- return True
+ # Use TFConfigClusterResolver to parse TF_CONFIG. We don't want to change
+ # the constructor's interface to allow customized cluster resolver. Use
+ # SimpleClusterResolver to override num_accelerators.
+ tfconfig = TFConfigClusterResolver()
+ cluster_resolver = SimpleClusterResolver(
+ cluster_spec=tfconfig.cluster_spec(),
+ task_type=tfconfig.task_type,
+ task_id=tfconfig.task_id,
+ num_accelerators=num_gpus_per_worker)
+ super(CollectiveAllReduceExtended, self).__init__(
+ container_strategy, cluster_resolver=cluster_resolver)
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index bba0bce..acbe467 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -29,9 +29,13 @@
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.distribute import collective_all_reduce_strategy as core_collective_all_reduce_strategy
from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.distribute import reduce_util
from tensorflow.python.distribute import values
+from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -49,6 +53,55 @@
from tensorflow.python.platform import test
from tensorflow.python.training import adam
from tensorflow.python.training import training_util
+from tensorflow.python.training.server_lib import ClusterSpec
+
+
+class MockCollectiveAllReduceStrategy(distribute_lib.DistributionStrategy):
+ """Mock the strategy to allow cluster resolver as an argument."""
+
+ def __init__(self, cluster_resolver):
+ super(MockCollectiveAllReduceStrategy, self).__init__(
+ core_collective_all_reduce_strategy.CollectiveAllReduceExtended(
+ self, cluster_resolver=cluster_resolver))
+
+
+def create_test_objects(cluster_spec=None,
+ task_type=None,
+ task_id=None,
+ num_gpus=None,
+ use_core_strategy=False):
+ sess_config = config_pb2.ConfigProto()
+ if num_gpus is None:
+ num_gpus = context.num_gpus()
+ if use_core_strategy:
+ if cluster_spec and task_type and task_id is not None:
+ cluster_resolver = SimpleClusterResolver(
+ cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
+ task_type=task_type,
+ task_id=task_id,
+ num_accelerators=num_gpus)
+ target = 'grpc://' + cluster_spec[task_type][task_id]
+ else:
+ cluster_resolver = SimpleClusterResolver(
+ ClusterSpec({}), num_accelerators=num_gpus)
+ target = ''
+
+ strategy = MockCollectiveAllReduceStrategy(cluster_resolver)
+ sess_config = strategy.update_config_proto(sess_config)
+ else:
+ strategy = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
+ num_gpus_per_worker=num_gpus)
+ if task_type and task_id is not None:
+ strategy.configure(
+ session_config=sess_config,
+ cluster_spec=cluster_spec,
+ task_type=task_type,
+ task_id=task_id)
+ target = 'grpc://' + cluster_spec[task_type][task_id]
+ else:
+ target = ''
+
+ return strategy, target, sess_config
class CollectiveAllReduceStrategyTestBase(
@@ -64,16 +117,18 @@
CollectiveAllReduceStrategyTestBase.collective_key_base += 100000
super(CollectiveAllReduceStrategyTestBase, self).setUp()
- def _get_test_object(self, task_type, task_id, num_gpus=0):
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=num_gpus)
- session_config = config_pb2.ConfigProto()
- if task_type and task_id is not None:
- distribution.configure(
- session_config=session_config,
- cluster_spec=self._cluster_spec,
- task_type=task_type,
- task_id=task_id)
+ def _get_test_object(self,
+ task_type,
+ task_id,
+ num_gpus=0,
+ use_core_strategy=False):
+ strategy, target, session_config = create_test_objects(
+ cluster_spec=self._cluster_spec,
+ task_type=task_type,
+ task_id=task_id,
+ num_gpus=num_gpus,
+ use_core_strategy=use_core_strategy)
+
collective_keys = cross_device_utils.CollectiveKeys(
group_key_start=10 * num_gpus +
CollectiveAllReduceStrategyTestBase.collective_key_base,
@@ -81,16 +136,16 @@
CollectiveAllReduceStrategyTestBase.collective_key_base,
instance_key_with_id_start=num_gpus * 10000 +
CollectiveAllReduceStrategyTestBase.collective_key_base)
- distribution.extended._collective_keys = collective_keys
- distribution.extended._cross_device_ops._collective_keys = (
- collective_keys)
- if task_type and task_id is not None:
- return distribution, 'grpc://' + self._cluster_spec[task_type][
- task_id], session_config
- else:
- return distribution, '', session_config
+ strategy.extended._collective_keys = collective_keys
+ strategy.extended._cross_device_ops._collective_keys = (collective_keys)
- def _test_minimize_loss_graph(self, task_type, task_id, num_gpus):
+ return strategy, target, session_config
+
+ def _test_minimize_loss_graph(self,
+ task_type,
+ task_id,
+ num_gpus,
+ use_core_strategy=False):
d, master_target, config = self._get_test_object(task_type, task_id,
num_gpus)
with ops.Graph().as_default(), \
@@ -158,7 +213,11 @@
self.assertLess(error_after, error_before)
return error_after < error_before
- def _test_complex_model(self, task_type, task_id, num_gpus):
+ def _test_complex_model(self,
+ task_type,
+ task_id,
+ num_gpus,
+ use_core_strategy=False):
d, master_target, config = self._get_test_object(task_type, task_id,
num_gpus)
@@ -210,7 +269,11 @@
sess.run(train_op)
return True
- def _test_variable_initialization(self, task_type, task_id, num_gpus):
+ def _test_variable_initialization(self,
+ task_type,
+ task_id,
+ num_gpus,
+ use_core_strategy=False):
distribution, master_target, config = self._get_test_object(
task_type, task_id, num_gpus)
with ops.Graph().as_default(), \
@@ -239,8 +302,14 @@
reduced_x_value)))
return np.allclose(x_value, reduced_x_value, atol=1e-5)
- def _test_input_fn_iterator(self, task_type, task_id, num_gpus, input_fn,
- expected_values, test_reinitialize=True):
+ def _test_input_fn_iterator(self,
+ task_type,
+ task_id,
+ num_gpus,
+ input_fn,
+ expected_values,
+ test_reinitialize=True,
+ use_core_strategy=False):
distribution, master_target, config = self._get_test_object(
task_type, task_id, num_gpus)
devices = distribution.extended.worker_devices
@@ -284,45 +353,72 @@
cls._cluster_spec = multi_worker_test_base.create_in_process_cluster(
num_workers=3, num_ps=0)
- def test_num_replicas_in_sync(self):
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=2)
- distribution.configure(cluster_spec=self._cluster_spec, task_type='worker',
- task_id=0)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def test_num_replicas_in_sync(self, use_core_strategy):
+ distribution, _, _ = create_test_objects(
+ cluster_spec=self._cluster_spec,
+ task_type='worker',
+ task_id=0,
+ num_gpus=2,
+ use_core_strategy=use_core_strategy)
num_workers = len(self._cluster_spec.get('chief', []) +
self._cluster_spec.get('worker', []))
self.assertEqual(2 * num_workers,
distribution.num_replicas_in_sync)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
- def testMinimizeLossGraph(self, num_gpus):
- self._run_between_graph_clients(self._test_minimize_loss_graph,
- self._cluster_spec, num_gpus)
+ combinations.combine(
+ mode=['graph'],
+ num_gpus=[0, 1, 2],
+ required_gpus=1,
+ use_core_strategy=[True, False]))
+ def testMinimizeLossGraph(self, num_gpus, use_core_strategy):
+ self._run_between_graph_clients(
+ self._test_minimize_loss_graph,
+ self._cluster_spec,
+ num_gpus,
+ use_core_strategy=use_core_strategy)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
- def testVariableInitialization(self, num_gpus):
+ combinations.combine(
+ mode=['graph'],
+ num_gpus=[0, 1, 2],
+ required_gpus=1,
+ use_core_strategy=[True, False]))
+ def testVariableInitialization(self, num_gpus, use_core_strategy):
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
self._run_between_graph_clients(
self._test_variable_initialization,
self._cluster_spec,
- num_gpus=num_gpus)
+ num_gpus=num_gpus,
+ use_core_strategy=use_core_strategy)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1))
- def testComplexModel(self, num_gpus):
+ combinations.combine(
+ mode=['graph'],
+ num_gpus=[0, 1, 2],
+ required_gpus=1,
+ use_core_strategy=[True, False]))
+ def testComplexModel(self, num_gpus, use_core_strategy):
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
self._run_between_graph_clients(
- self._test_complex_model, self._cluster_spec, num_gpus=num_gpus)
+ self._test_complex_model,
+ self._cluster_spec,
+ num_gpus=num_gpus,
+ use_core_strategy=use_core_strategy)
# TODO(yuefengz): Update how we use num_gpus and required_gpus
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[0, 1, 2], required_gpus=1,
- use_dataset=[True, False]))
- def testMakeInputFnIterator(self, num_gpus, use_dataset):
+ combinations.combine(
+ mode=['graph'],
+ num_gpus=[0, 1, 2],
+ required_gpus=1,
+ use_dataset=[True, False],
+ use_core_strategy=[True, False]))
+ def testMakeInputFnIterator(self, num_gpus, use_dataset, use_core_strategy):
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
if use_dataset:
@@ -342,21 +438,29 @@
expected_num_replicas_in_sync=3*devices_per_worker,
expected_num_input_pipelines=3,
expected_input_pipeline_id=1) # because task_id = 1
- self._test_input_fn_iterator('worker', 1, num_gpus,
- input_fn, expected_values,
- test_reinitialize=use_dataset)
+ self._test_input_fn_iterator(
+ 'worker',
+ 1,
+ num_gpus,
+ input_fn,
+ expected_values,
+ test_reinitialize=use_dataset,
+ use_core_strategy=use_core_strategy)
- def testUpdateConfigProto(self):
- distribution = collective_all_reduce_strategy.CollectiveAllReduceStrategy(
- num_gpus_per_worker=2)
- distribution.configure(
- cluster_spec=self._cluster_spec, task_type='worker', task_id=1)
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testUpdateConfigProto(self, use_core_strategy):
+ strategy, _, _ = self._get_test_object(
+ task_type='worker',
+ task_id=1,
+ num_gpus=2,
+ use_core_strategy=use_core_strategy)
config_proto = config_pb2.ConfigProto(device_filters=['to_be_overridden'])
rewrite_options = config_proto.graph_options.rewrite_options
rewrite_options.scoped_allocator_opts.enable_op.append('to_be_removed')
- new_config = distribution.update_config_proto(config_proto)
+ new_config = strategy.update_config_proto(config_proto)
# Verify group leader
self.assertEqual('/job:worker/replica:0/task:0',
@@ -415,28 +519,41 @@
@combinations.generate(
combinations.combine(
- mode=['graph', 'eager'], num_gpus=[2, 4], required_gpus=2))
- def testMinimizeLoss(self, num_gpus):
+ mode=['graph', 'eager'],
+ num_gpus=[2, 4],
+ required_gpus=2,
+ use_core_strategy=[True, False]))
+ def testMinimizeLoss(self, num_gpus, use_core_strategy):
# Collective ops doesn't support strategy with one device.
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
if context.executing_eagerly():
- strategy, _, _ = self._get_test_object(None, None, num_gpus)
+ strategy, _, _ = self._get_test_object(
+ None, None, num_gpus, use_core_strategy=use_core_strategy)
self._test_minimize_loss_eager(strategy)
else:
- self._test_minimize_loss_graph(None, None, num_gpus)
+ self._test_minimize_loss_graph(
+ None, None, num_gpus, use_core_strategy=use_core_strategy)
@combinations.generate(
- combinations.combine(mode=['graph'], num_gpus=[2, 4], required_gpus=2))
- def testComplexModel(self, num_gpus):
+ combinations.combine(
+ mode=['graph'],
+ num_gpus=[2, 4],
+ required_gpus=2,
+ use_core_strategy=[True, False]))
+ def testComplexModel(self, num_gpus, use_core_strategy):
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
- self._test_complex_model(None, None, num_gpus)
+ self._test_complex_model(
+ None, None, num_gpus, use_core_strategy=use_core_strategy)
@combinations.generate(
- combinations.combine(mode=['graph', 'eager'], required_gpus=2,
- use_dataset=[True, False]))
- def testMakeInputFnIterator(self, use_dataset):
+ combinations.combine(
+ mode=['graph', 'eager'],
+ required_gpus=2,
+ use_dataset=[True, False],
+ use_core_strategy=[True, False]))
+ def testMakeInputFnIterator(self, use_dataset, use_core_strategy):
num_gpus = 2
if use_dataset:
fn = lambda: dataset_ops.Dataset.range(5 * num_gpus)
@@ -452,51 +569,77 @@
expected_num_replicas_in_sync=num_gpus,
expected_num_input_pipelines=1,
expected_input_pipeline_id=0)
- self._test_input_fn_iterator(None, None, num_gpus,
- input_fn, expected_values,
- test_reinitialize=use_dataset)
+ self._test_input_fn_iterator(
+ None,
+ None,
+ num_gpus,
+ input_fn,
+ expected_values,
+ test_reinitialize=use_dataset,
+ use_core_strategy=use_core_strategy)
- def testAllReduceSum(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceSum(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_sum(distribution)
- def testAllReduceSumGradients(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceSumGradients(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_sum_gradients(distribution)
- def testAllReduceSumGradientTape(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceSumGradientTape(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_sum_gradient_tape(distribution)
- def testAllReduceMean(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceMean(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_mean(distribution)
- def testAllReduceMeanGradients(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceMeanGradients(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_mean_gradients(distribution)
- def testAllReduceMeanGradientTape(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testAllReduceMeanGradientTape(self, use_core_strategy):
if context.num_gpus() < 2: self.skipTest('Not enough GPUs')
- distribution, target, config = self._get_test_object(None, None, num_gpus=2)
+ distribution, target, config = self._get_test_object(
+ None, None, num_gpus=2, use_core_strategy=use_core_strategy)
with self.cached_session(config=config, target=target):
self._test_all_reduce_mean_gradient_tape(distribution)
- def testNumpyIterator(self):
+ @combinations.generate(
+ combinations.combine(mode=['graph'], use_core_strategy=[True, False]))
+ def testNumpyIterator(self, use_core_strategy):
num_gpus = 2
if context.num_gpus() < num_gpus:
self.skipTest('Not enough GPUs')
- strategy, _, _ = self._get_test_object(None, None, num_gpus)
+ strategy, _, _ = self._get_test_object(
+ None, None, num_gpus=num_gpus, use_core_strategy=use_core_strategy)
self._test_numpy_iterator(strategy)
diff --git a/tensorflow/contrib/distribute/python/combinations.py b/tensorflow/contrib/distribute/python/combinations.py
index 798a159..7c0f803 100644
--- a/tensorflow/contrib/distribute/python/combinations.py
+++ b/tensorflow/contrib/distribute/python/combinations.py
@@ -352,6 +352,9 @@
one_device_strategy = NamedDistribution(
"OneDeviceCPU", lambda: one_device_lib.OneDeviceStrategy("/cpu:0"),
required_gpus=None)
+one_device_strategy_gpu = NamedDistribution(
+ "OneDeviceGPU", lambda: one_device_lib.OneDeviceStrategy("/gpu:0"),
+ required_gpus=1)
tpu_strategy = NamedDistribution(
"TPU", _get_tpu_strategy_creator(steps_per_run=2),
required_tpu=True)
diff --git a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py
index 11a3b5e..c045a55 100644
--- a/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py
+++ b/tensorflow/contrib/distribute/python/examples/mnist_eager_multigpu.py
@@ -25,20 +25,23 @@
from __future__ import division
from __future__ import print_function
+from absl import app
+from absl import flags
import numpy as np
-import tensorflow as tf
+import tensorflow.compat.v2 as tf
-tf.flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?"
- "Defaults to all available GPUs, otherwise CPU.")
-tf.flags.DEFINE_integer("batch_size", 64,
- "What should be the size of each batch?")
-tf.flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?")
-tf.flags.DEFINE_float("learning_rate", 0.01, "Learning Rate")
-tf.flags.DEFINE_float("momentum", 0.5, "SGD momentum")
+flags.DEFINE_integer("num_gpus", None, "How many GPUs should we run on?"
+ "Defaults to all available GPUs, otherwise CPU.")
+flags.DEFINE_integer("batch_size", 64,
+ "What should be the size of each batch?")
+flags.DEFINE_integer("num_epochs", 10, "How many epochs to run?")
+flags.DEFINE_float("learning_rate", 0.01, "Learning Rate")
+flags.DEFINE_float("momentum", 0.5, "SGD momentum")
+flags.DEFINE_boolean("use_function", False,
+ "Should we wrap the step in a tf.function.")
-FLAGS = tf.flags.FLAGS
+FLAGS = flags.FLAGS
NUM_TRAIN_IMAGES = 60000
-NUM_TEST_IMAGES = 10000
def create_model():
@@ -82,7 +85,7 @@
def main(unused_argv):
"""Run a CNN model on MNIST data to demonstrate DistributedStrategies."""
- tf.enable_eager_execution()
+ tf.enable_v2_behavior()
num_gpus = FLAGS.num_gpus
if num_gpus is None:
@@ -99,7 +102,7 @@
test_ds = test_ds.batch(FLAGS.batch_size)
model = create_model()
- optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
+ optimizer = tf.keras.optimizers.SGD(FLAGS.learning_rate, FLAGS.momentum)
training_loss = tf.keras.metrics.Mean("training_loss", dtype=tf.float32)
training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
"training_accuracy", dtype=tf.float32)
@@ -126,12 +129,24 @@
train_iterator = strategy.make_dataset_iterator(train_ds)
test_iterator = strategy.make_dataset_iterator(test_ds)
+
for epoch in range(0, FLAGS.num_epochs):
+ # TODO(b/123315763): Create the tf.function outside this loop once we are
+ # able to initialize iterator in eager mode.
+ dist_train = lambda it: strategy.experimental_run(train_step, it)
+ dist_test = lambda it: strategy.experimental_run(test_step, it)
+ if FLAGS.use_function:
+ dist_train = tf.function(dist_train)
+ dist_test = tf.function(dist_test)
+
# Train
print("Starting epoch {}".format(epoch))
train_iterator.initialize()
- for _ in range(NUM_TRAIN_IMAGES // FLAGS.batch_size):
- strategy.experimental_run(train_step, train_iterator)
+ while True:
+ try:
+ dist_train(train_iterator)
+ except tf.errors.OutOfRangeError:
+ break
print("Training loss: {:0.4f}, accuracy: {:0.2f}%".format(
training_loss.result(), training_accuracy.result() * 100))
training_loss.reset_states()
@@ -139,8 +154,11 @@
# Test
test_iterator.initialize()
- for _ in range(NUM_TEST_IMAGES // FLAGS.batch_size):
- strategy.experimental_run(test_step, test_iterator)
+ while True:
+ try:
+ dist_test(test_iterator)
+ except tf.errors.OutOfRangeError:
+ break
print("Test loss: {:0.4f}, accuracy: {:0.2f}%".format(
test_loss.result(), test_accuracy.result() * 100))
test_loss.reset_states()
@@ -148,4 +166,4 @@
if __name__ == "__main__":
- tf.app.run()
+ app.run(main)
diff --git a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
index b783ab8..9a581e7 100644
--- a/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
+++ b/tensorflow/contrib/distribute/python/keras_backward_compat_test.py
@@ -34,6 +34,7 @@
from tensorflow.python.ops.parsing_ops import gen_parsing_ops
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import rmsprop
+from tensorflow.python.training.mode_keys import ModeKeys
_RANDOM_SEED = 1337
_TRAIN_SIZE = 200
@@ -745,7 +746,9 @@
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
callbacks=[keras.callbacks.LearningRateScheduler(schedule)])
- grouped_models = distribution.unwrap(model._distributed_model_train)
+ grouped_models = distribution.unwrap(
+ distributed_training_utils.get_distributed_model(
+ model, ModeKeys.TRAIN))
with distribution.scope():
for m in grouped_models:
self.assertAllClose(0.001, keras.backend.get_value(
@@ -791,16 +794,21 @@
verbose=0,
sample_weight=sample_weight)
- # Test with not specifying the `steps` argument.
- with self.assertRaisesRegexp(
- ValueError, 'the `steps_per_epoch` argument'):
+ # Test with not specifying the `steps` argument for dataset with
+ # infinite cardinality.
+ dataset = dataset.repeat()
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps_per_epoch` argument'):
model.fit(dataset, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'the `steps` argument'):
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps` argument'):
model.evaluate(dataset, verbose=0)
- with self.assertRaisesRegexp(ValueError,
- 'the `steps` argument'):
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps` argument'):
model.predict(dataset, verbose=0)
@combinations.generate(combinations.combine(
diff --git a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
index 952b119..c93d7afa 100644
--- a/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/keras_optimizer_v2_test.py
@@ -64,7 +64,7 @@
def loss_fn():
replica_id = _replica_id()
- return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * var
+ return math_ops.cast(replica_id + 1, dtype=dtypes.float32) * 0.5 * var
train_op = optimizer.minimize(loss_fn, var_list=[var])
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 73c27d6..6a15083 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -230,6 +230,8 @@
return model
+# TODO(josh11b): Add combinations.one_device_strategy_gpu once it works with
+# TestDistributionStrategyWithCallbacks.test_callbacks_in_predict.
strategies_minus_tpu = [
combinations.default_strategy,
combinations.one_device_strategy,
@@ -244,15 +246,13 @@
def strategy_minus_tpu_combinations():
- return combinations.combine(
- distribution=strategies_minus_tpu,
- mode=['graph', 'eager'])
+ return combinations.combine(distribution=strategies_minus_tpu,
+ mode=['graph', 'eager'])
def tpu_strategy_combinations():
- return combinations.combine(
- distribution=tpu_strategies,
- mode=['graph'])
+ return combinations.combine(distribution=tpu_strategies,
+ mode=['graph'])
def all_strategy_combinations():
@@ -287,9 +287,9 @@
def strategy_for_numpy_input_combinations():
- return combinations.combine(
- distribution=strategies_minus_tpu + tpu_strategies,
- mode=['graph'])
+ one_gpu = combinations.one_device_strategy_gpu
+ return (all_strategy_combinations() +
+ combinations.combine(distribution=[one_gpu], mode=['graph', 'eager']))
class TestEstimatorDistributionStrategy(test_util.TensorFlowTestCase,
@@ -841,6 +841,42 @@
model.fit(dataset_dict, epochs=1, steps_per_epoch=2, verbose=1)
+ # TODO(b/122743976): Include TPUStrategy for this test as well once
+ # step inference is supported.
+ @combinations.generate(strategy_minus_tpu_combinations())
+ def test_fit_eval_and_predict_methods_on_dataset_without_steps(
+ self, distribution):
+ with self.cached_session():
+ with distribution.scope():
+ model = get_model()
+ optimizer = gradient_descent.GradientDescentOptimizer(0.001)
+ loss = 'mse'
+ metrics = ['mae', keras.metrics.CategoricalAccuracy()]
+ model.compile(optimizer, loss, metrics=metrics)
+
+ inputs = np.zeros((1000, 3), dtype=np.float32)
+ targets = np.zeros((1000, 4), dtype=np.float32)
+ # steps/steps_per_epoch are calculated when using numpy arrays as
+ # input data.
+ fit_with_numpy = model.fit(inputs, targets, epochs=1,
+ batch_size=10).history
+ eval_with_numpy = model.evaluate(inputs, targets, batch_size=10)
+ predict_with_numpy = model.predict(inputs, batch_size=10)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
+ dataset = dataset.batch(10)
+ fit_with_ds = model.fit(dataset, epochs=1).history
+ eval_with_ds = model.evaluate(dataset)
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ predict_dataset = predict_dataset.batch(10, drop_remainder=True)
+ predict_with_ds = model.predict(predict_dataset)
+ self.assertAllClose(
+ fit_with_numpy, fit_with_ds, atol=1e-4, rtol=1e-4)
+ self.assertAllClose(
+ eval_with_numpy, eval_with_ds, atol=1e-4, rtol=1e-4)
+ self.assertAllClose(
+ predict_with_numpy, predict_with_ds, atol=1e-4, rtol=1e-4)
+
@combinations.generate(all_strategy_combinations())
def test_fit_eval_and_predict_methods_on_dataset(self, distribution):
with self.cached_session():
@@ -1120,16 +1156,7 @@
class TestDistributionStrategyWithCallbacks(test.TestCase,
parameterized.TestCase):
- def _check_counts(self, counter, expected_counts):
- """Checks that the counts registered by `counter` are those expected."""
- for method_name, expected_count in expected_counts.items():
- self.assertEqual(
- counter.method_counts[method_name],
- expected_count,
- msg='For method {}: expected {}, got: {}'.format(
- method_name, expected_count, counter.method_counts[method_name]))
-
- @combinations.generate(strategy_minus_tpu_combinations())
+ @combinations.generate(all_strategy_combinations())
def test_callbacks_in_fit(self, distribution):
with distribution.scope():
model = get_model()
@@ -1138,36 +1165,46 @@
dataset = get_dataset(distribution)
counter = Counter()
+ epochs = 2
+ steps_per_epoch = 5
+ validation_steps = 3
+
model.fit(
dataset,
- epochs=2,
- steps_per_epoch=5,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
verbose=0,
validation_data=dataset,
- validation_steps=2,
+ validation_steps=validation_steps,
callbacks=[counter])
- self._check_counts(
- counter, {
- 'on_batch_begin': 10,
- 'on_batch_end': 10,
- 'on_epoch_begin': 2,
- 'on_epoch_end': 2,
- 'on_predict_batch_begin': 0,
- 'on_predict_batch_end': 0,
- 'on_predict_begin': 0,
- 'on_predict_end': 0,
- 'on_test_batch_begin': 4,
- 'on_test_batch_end': 4,
- 'on_test_begin': 2,
- 'on_test_end': 2,
- 'on_train_batch_begin': 10,
- 'on_train_batch_end': 10,
+ if isinstance(distribution, tpu_strategy.TPUStrategy):
+ # TPU Strategy can have multi step training, from extended.steps_per_run
+ # if steps_per_run = 1, then num_batch_call_per_epoch = steps_per_epoch
+ steps_per_run = distribution.extended.steps_per_run
+ num_batch_call_per_epoch = steps_per_epoch // steps_per_run
+ if steps_per_epoch % steps_per_run:
+ num_batch_call_per_epoch += 1
+ else:
+ num_batch_call_per_epoch = steps_per_epoch
+
+ self.assertDictEqual(
+ counter.method_counts, {
+ 'on_batch_begin': epochs * num_batch_call_per_epoch,
+ 'on_batch_end': epochs * num_batch_call_per_epoch,
+ 'on_epoch_begin': epochs,
+ 'on_epoch_end': epochs,
+ 'on_test_batch_begin': epochs * validation_steps,
+ 'on_test_batch_end': epochs * validation_steps,
+ 'on_test_begin': epochs,
+ 'on_test_end': epochs,
+ 'on_train_batch_begin': epochs * num_batch_call_per_epoch,
+ 'on_train_batch_end': epochs * num_batch_call_per_epoch,
'on_train_begin': 1,
'on_train_end': 1
})
- @combinations.generate(strategy_minus_tpu_combinations())
+ @combinations.generate(all_strategy_combinations())
def test_callbacks_in_eval(self, distribution):
with distribution.scope():
model = get_model()
@@ -1178,15 +1215,15 @@
model.evaluate(dataset, steps=5, callbacks=[counter])
- self._check_counts(
- counter, {
+ self.assertDictEqual(
+ counter.method_counts, {
'on_test_batch_begin': 5,
'on_test_batch_end': 5,
'on_test_begin': 1,
'on_test_end': 1
})
- @combinations.generate(strategy_minus_tpu_combinations())
+ @combinations.generate(all_strategy_combinations())
def test_callbacks_in_predict(self, distribution):
with distribution.scope():
model = get_model()
@@ -1197,8 +1234,8 @@
model.predict(get_predict_dataset(dataset), steps=5, callbacks=[counter])
- self._check_counts(
- counter, {
+ self.assertDictEqual(
+ counter.method_counts, {
'on_predict_batch_begin': 5,
'on_predict_batch_end': 5,
'on_predict_begin': 1,
@@ -1293,14 +1330,21 @@
verbose=0,
sample_weight=sample_weight)
- # Test with not specifying the `steps` argument.
- with self.assertRaisesRegexp(
- ValueError, 'the `steps_per_epoch` argument'):
+ # Test with not specifying the `steps` argument for dataset with infinite
+ # cardinality.
+ dataset = dataset.repeat()
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps_per_epoch` argument'):
model.fit(dataset, epochs=1, verbose=0)
- with self.assertRaisesRegexp(ValueError, 'the `steps` argument'):
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps` argument'):
model.evaluate(dataset, verbose=0)
- with self.assertRaisesRegexp(ValueError, 'the `steps` argument'):
+ with self.assertRaisesRegexp(ValueError, 'When passing an infinitely '
+ 'repeating dataset, you must specify the '
+ '`steps` argument'):
model.predict(dataset, verbose=0)
@combinations.generate(combinations.combine(
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 0b8df78..bc0572b 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -606,6 +606,7 @@
aggregation="invalid")
def testNonMatchingVariableCreation(self, distribution):
+ self.skipTest("b/123075960")
def model_fn(name):
v = variable_scope.variable(1.0, name=name)
ds_context.get_replica_context().merge_call(lambda _: _)
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy_test.py b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
index 906bffc..93c2447 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import sys
+
from tensorflow.contrib.distribute.python import strategy_test_lib
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import one_device_strategy
@@ -29,8 +31,12 @@
strategy_test_lib.DistributionTestBase,
strategy_test_lib.OneDeviceDistributionTestBase):
+ # TODO(josh11b): Switch to using the combinations library.
def _get_distribution_strategy(self):
- return one_device_strategy.OneDeviceStrategy("/device:CPU:0")
+ if "test_gpu" in sys.argv[0]:
+ return one_device_strategy.OneDeviceStrategy("/device:GPU:0")
+ else:
+ return one_device_strategy.OneDeviceStrategy("/device:CPU:0")
def testMinimizeLossEager(self):
self._test_minimize_loss_eager(self._get_distribution_strategy())
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index 69ce114..341d9ae 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -155,8 +155,7 @@
def __init__(self,
tpu_cluster_resolver=None,
steps_per_run=None,
- device_assignment=None,
- **kwargs):
+ device_assignment=None):
"""Initializes the TPUStrategy object.
Args:
@@ -170,18 +169,9 @@
device_assignment: Optional `tf.contrib.tpu.DeviceAssignment` to specify
the placement of replicas on the TPU cluster. Currently only supports
the usecase of using a single core within a TPU cluster.
- **kwargs: Additional experimental flags. Will be removed in future.
"""
- if len(kwargs) > 1:
- raise ValueError("TPUStrategy constructor only takes one experimental "
- "flag now")
- elif len(kwargs) == 1 and "_disable_training_loop_on_host" not in kwargs:
- raise ValueError("TPUStrategy constructor does not support arguments: "
- "{}".format(kwargs))
-
super(TPUStrategy, self).__init__(TPUExtended(
- self, tpu_cluster_resolver, steps_per_run, device_assignment,
- kwargs.get("_disable_training_loop_on_host", False)))
+ self, tpu_cluster_resolver, steps_per_run, device_assignment))
@property
def steps_per_run(self):
@@ -196,11 +186,6 @@
if context.executing_eagerly():
raise NotImplementedError("Eager mode not supported in TPUStrategy.")
- if self.extended._disable_training_loop_on_host: # pylint: disable=protected-access
- raise NotImplementedError(
- "`experimental_run` is not compatible with "
- "`_disable_training_loop_on_host=True`")
-
if input_iterator is None:
inputs = []
else:
@@ -241,8 +226,7 @@
container_strategy,
tpu_cluster_resolver=None,
steps_per_run=None,
- device_assignment=None,
- disable_training_loop_on_host=False):
+ device_assignment=None):
super(TPUExtended, self).__init__(container_strategy)
if tpu_cluster_resolver is None:
@@ -256,7 +240,6 @@
self._tpu_cluster_resolver = tpu_cluster_resolver
self._tpu_metadata = get_tpu_system_metadata(self._tpu_cluster_resolver)
self._device_assignment = device_assignment
- self._disable_training_loop_on_host = disable_training_loop_on_host
# Device assignment is currently only supported for 1 core case.
if self._device_assignment:
@@ -284,25 +267,14 @@
self._tpu_devices = self._tpu_devices[:self._num_replicas_in_sync]
self._device_map = values.ReplicaDeviceMap(self._tpu_devices)
- # If the training loop is on the device, we must use the infeed, with input
- # on the host. Otherwise, we preload the data onto the TPUs.
- if disable_training_loop_on_host:
- input_device_map = values.ReplicaDeviceMap(tuple(
- self.get_host_cpu_device(hid) for hid in range(self.num_hosts)))
- worker_devices = [
- (self.get_host(hid), [self.get_host_cpu_device(hid)])
- for hid in range(self.num_hosts)
- ]
- self._input_workers = input_lib.InputWorkers(
- input_device_map, worker_devices)
- else:
- input_worker_devices = collections.OrderedDict()
- for tpu_device in self._tpu_devices:
- host_device = _get_host_for_device(tpu_device)
- input_worker_devices.setdefault(host_device, [])
- input_worker_devices[host_device].append(tpu_device)
- self._input_workers = input_lib.InputWorkers(
- self._device_map, tuple(input_worker_devices.items()))
+ # Preload the data onto the TPUs.
+ input_worker_devices = collections.OrderedDict()
+ for tpu_device in self._tpu_devices:
+ host_device = _get_host_for_device(tpu_device)
+ input_worker_devices.setdefault(host_device, [])
+ input_worker_devices[host_device].append(tpu_device)
+ self._input_workers = input_lib.InputWorkers(
+ self._device_map, tuple(input_worker_devices.items()))
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
@@ -402,17 +374,6 @@
# a mechanism to infer the outputs of `fn`. Pending b/110550782.
def _experimental_run_steps_on_iterator(
self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
- if self._disable_training_loop_on_host:
- impl = self._run_steps_on_iterator_with_device_loop
- else:
- impl = self._run_steps_on_iterator_with_host_loop
-
- return impl(
- fn=fn, multi_worker_iterator=multi_worker_iterator,
- iterations=iterations, initial_loop_values=initial_loop_values)
-
- def _run_steps_on_iterator_with_host_loop(
- self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
output_shapes = multi_worker_iterator.output_shapes
shapes = nest.flatten(output_shapes)
if any(not s.is_fully_defined() for s in shapes):
@@ -507,79 +468,6 @@
_set_last_step_outputs(ctx, last_step_tensor_outputs)
return ctx
- def _run_steps_on_iterator_with_device_loop(
- self, fn, multi_worker_iterator, iterations, initial_loop_values=None):
- output_shapes = multi_worker_iterator.output_shapes
- shapes = nest.flatten(output_shapes)
- if any(not s.is_fully_defined() for s in shapes):
- raise ValueError(
- "TPU currently requires fully defined shapes. Either use "
- "set_shape() on the input tensors or use "
- "dataset.batch(..., drop_remainder=True).")
- types = nest.flatten(multi_worker_iterator.output_types)
-
- enqueue_ops = [
- self._get_enqueue_op_per_host(host_id, multi_worker_iterator, shapes,
- iterations)
- for host_id in range(self.num_hosts)]
-
- def dequeue_fn():
- dequeued = tpu_ops.infeed_dequeue_tuple(dtypes=types, shapes=shapes)
- return nest.pack_sequence_as(output_shapes, dequeued)
-
- # Wrap `fn` for repeat.
- if initial_loop_values is None:
- initial_loop_values = {}
- initial_loop_values = nest.flatten(initial_loop_values)
- ctx = input_lib.MultiStepContext()
-
- def run_fn(*args, **kwargs):
- """Single step on the TPU device."""
- del args, kwargs
- fn_result = fn(ctx, dequeue_fn())
- flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
- if flat_last_step_outputs:
- with ops.control_dependencies([fn_result]):
- return [array_ops.identity(f) for f in flat_last_step_outputs]
- else:
- return fn_result
-
- def iterate_on_tpu():
- return training_loop.repeat(iterations, run_fn, initial_loop_values)
-
- # We capture the control_flow_context at this point, before we run `fn`
- # inside a while_loop and TPU replicate context. This is useful in cases
- # where we might need to exit these contexts and get back to the outer
- # context to do some things, for e.g. create an op which should be
- # evaluated only once at the end of the loop on the host. One such usage
- # is in creating metrics' value op.
- self._outer_control_flow_context = (
- ops.get_default_graph()._get_control_flow_context()) # pylint: disable=protected-access
-
- replicate_inputs = [[]] * self._num_replicas_in_sync
- replicate_outputs = tpu.replicate(iterate_on_tpu, replicate_inputs)
-
- del self._outer_control_flow_context
- ctx.run_op = control_flow_ops.group(replicate_outputs, enqueue_ops)
-
- # Filter out any ops from the outputs, typically this would be the case
- # when there were no tensor outputs.
- last_step_tensor_outputs = [x for x in replicate_outputs
- if not isinstance(x, ops.Operation)]
-
- # Outputs are currently of the structure (grouped by device)
- # [[output0_device0, output1_device0, output2_device0],
- # [output0_device1, output1_device1, output2_device1]]
- # Convert this to the following structure instead: (grouped by output)
- # [[output0_device0, output0_device1],
- # [output1_device0, output1_device1],
- # [output2_device0, output2_device1]]
- last_step_tensor_outputs = [list(x) for x in
- zip(*last_step_tensor_outputs)]
-
- _set_last_step_outputs(ctx, last_step_tensor_outputs)
- return ctx
-
def _call_for_each_replica(self, fn, args, kwargs):
# TODO(jhseu): Consider making it so call_for_each_replica implies that
# we're in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
@@ -655,19 +543,24 @@
return cross_device_ops_lib.reduce_non_distributed_value(
reduce_op, self._device_map, value, destinations)
- # Validate that the destination is same as the host device
- # Note we don't do this when in replicate context as the reduction is
- # performed on the TPU device itself.
devices = cross_device_ops_lib.get_devices_from(destinations)
- if len(devices) == 1:
- assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
- self._host_device)
- else:
+ if len(devices) != 1:
raise ValueError("Multiple devices are not supported for TPUStrategy")
- output = math_ops.add_n(value)
- if reduce_op == reduce_util.ReduceOp.MEAN:
- return output * (1. / len(value))
+ # Always performs the reduction on the TPU host.
+ with ops.device(self._host_device):
+ output = math_ops.add_n(value.values)
+ if reduce_op == reduce_util.ReduceOp.MEAN:
+ output *= (1. / len(value.values))
+
+ # If necessary, copy to requested destination.
+ dest_canonical = device_util.canonicalize(devices[0])
+ host_canonical = device_util.canonicalize(self._host_device)
+
+ if dest_canonical != host_canonical:
+ with ops.device(devices[0]):
+ output = array_ops.identity(output)
+
return output
def _update(self, var, fn, args, kwargs, group):
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 3079175..c230028 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -822,7 +822,7 @@
cuda_py_test(
name = "affine_test",
- size = "large",
+ size = "medium",
srcs = ["python/kernel_tests/bijectors/affine_test.py"],
additional_deps = [
":bijectors_py",
@@ -837,7 +837,7 @@
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
- shard_count = 5,
+ shard_count = 10,
tags = ["noasan"], # times out b/63678675
)
diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py
index b82e1bb..12bbdc0 100644
--- a/tensorflow/contrib/eager/python/tfe.py
+++ b/tensorflow/contrib/eager/python/tfe.py
@@ -62,7 +62,6 @@
@@Checkpoint
@@Checkpointable
-@@CheckpointableSaver
@@executing_eagerly
@@in_eager_mode
@@ -139,7 +138,6 @@
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import template
from tensorflow.python.training.checkpointable.tracking import AutoCheckpointable as Checkpointable
-from tensorflow.python.training.checkpointable.util import CheckpointableSaver
from tensorflow.python.training.checkpointable.util import Checkpoint
from tensorflow.python.util.all_util import remove_undocumented
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index cb86efb..48a6ef4 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -109,7 +109,7 @@
# Ops tests
tf_py_test(
name = "gmm_test",
- size = "large",
+ size = "medium",
srcs = [
"python/ops/gmm_test.py",
],
@@ -130,6 +130,7 @@
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
],
+ shard_count = 4,
tags = [
"no_pip", # b/38283730
"notsan", # Flaky: b/30756419
@@ -227,7 +228,7 @@
tf_py_test(
name = "wals_test",
- size = "large",
+ size = "medium",
srcs = ["python/ops/wals_test.py"],
additional_deps = [
":factorization_py",
@@ -250,8 +251,8 @@
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
+ shard_count = 4,
tags = [
- "manual",
"noasan", # times out b/63678675
"nomsan",
],
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index 3f6dbe0..c99f847 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -179,6 +179,7 @@
additional_deps = [
"//tensorflow/python:client_testlib",
":framework_py",
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_for_generated_wrappers",
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 3784631..fc2334d 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -139,6 +139,7 @@
'map_structure_with_tuple_paths',
'assert_shallow_structure',
'flatten_up_to',
+ 'flatten_with_tuple_paths_up_to',
'map_structure_up_to',
'map_structure_with_tuple_paths_up_to',
'get_traverse_shallow_structure',
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index 34fd501..d2bb4f4 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -19,6 +19,7 @@
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import critical_section_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -334,13 +335,22 @@
@test_util.run_in_graph_and_eager_modes
def testInsideFunction(self):
+ if test_util.is_gpu_available():
+ self.skipTest(
+ "b/123899495: Colocation errors for critical sections in map on GPU")
cs = critical_section_ops.CriticalSection()
- v = resource_variable_ops.ResourceVariable(1)
+ with ops.device("/gpu:0" if test_util.is_gpu_available() else "/cpu:0"):
+ v = resource_variable_ops.ResourceVariable(1)
def fn():
return v.read_value()
# map() creates a TensorFlow function.
- ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn))
+ ds = dataset_ops.Dataset.range(1)
+ if test_util.is_gpu_available():
+ ds = (ds.apply(prefetching_ops.copy_to_device("/gpu:0"))
+ .apply(prefetching_ops.map_on_gpu(lambda _: cs.execute(fn))))
+ else:
+ ds = ds.map(lambda _: cs.execute(fn))
def get_first():
if context.executing_eagerly():
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index e3c780a..44ee0f5 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -403,7 +403,9 @@
def test_all_correct(self):
loss = self._penalty_fn(**self._kwargs)
self.assertEqual(self._expected_dtype, loss.dtype)
- self.assertEqual(self._expected_op_name, loss.op.name)
+ # NOTE: Op names will change, it is inappropriate to include them in tests.
+ # See go/tf-breaking-change.
+ # self.assertEqual(self._expected_op_name, loss.op.name)
with self.cached_session():
variables.global_variables_initializer().run()
self.assertAlmostEqual(self._expected_loss, loss.eval(), 6)
diff --git a/tensorflow/contrib/mpi/mpi_server_lib.cc b/tensorflow/contrib/mpi/mpi_server_lib.cc
index a31fa9c..e44e10a 100644
--- a/tensorflow/contrib/mpi/mpi_server_lib.cc
+++ b/tensorflow/contrib/mpi/mpi_server_lib.cc
@@ -54,7 +54,10 @@
Status MPIServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func) {
- Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
+ GrpcServerOptions opts;
+ opts.service_func = service_func;
+ opts.rendezvous_mgr_func = rendezvous_mgr_func;
+ Status s = GrpcServer::Init(opts);
return s;
}
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 0243927..b5de726 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -44,6 +44,7 @@
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as core_saver
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -118,9 +119,8 @@
self.evaluate(util.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph, _ = (
- util._serialize_object_graph(
- root_checkpointable, saveables_cache=None))
+ named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
+ root_checkpointable).serialize_object_graph()
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@@ -440,7 +440,7 @@
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = tracking.AutoCheckpointable()
+ root = util.Checkpoint()
root.var = util.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
@@ -455,21 +455,17 @@
util.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
self.evaluate(state_ops.assign(root.var, 12.))
- no_slots_path = util.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "no_slots"))
+ no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
root.optimizer = optimizer
self.evaluate(state_ops.assign(root.var, 13.))
self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
14.))
- slots_path = util.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "with_slots"))
- new_root = tracking.AutoCheckpointable()
+ slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
+ new_root = util.Checkpoint()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
- slot_status = util.CheckpointableSaver(
- new_root).restore(slots_path)
- no_slot_status = util.CheckpointableSaver(
- new_root).restore(no_slots_path)
+ slot_status = new_root.restore(slots_path)
+ no_slot_status = new_root.restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
new_root.var = util.add_variable(
@@ -508,15 +504,14 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = util.Checkpoint()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
self.evaluate(util.gather_initializers(obj))
- saver = util.CheckpointableSaver(obj)
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
before_ops = graph.get_operations()
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
self.assertEqual(before_ops, graph.get_operations())
def testManyRestoresGraph(self):
@@ -526,16 +521,15 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = util.Checkpoint()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
self.evaluate(util.gather_initializers(obj))
- saver = util.CheckpointableSaver(obj)
- save_path = saver.save(checkpoint_prefix)
- saver.restore(save_path)
+ save_path = obj.save(checkpoint_prefix)
+ obj.restore(save_path)
before_ops = graph.get_operations()
- saver.restore(save_path)
+ obj.restore(save_path)
self.assertEqual(before_ops, graph.get_operations())
def testMultipleGraphsNonSlotVariables(self):
@@ -704,7 +698,7 @@
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
- object_saver = util.CheckpointableSaver(root)
+ object_saver = util.CheckpointableSaver(graph_view.ObjectGraphView(root))
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 1323ed0..a49149e 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -24,7 +24,6 @@
import six
-from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
@@ -661,7 +660,7 @@
name=None,
grad_loss=None,
stop_gradients=None,
- scale_loss_by_num_replicas=None):
+ scale_loss_by_num_replicas=False):
"""Add operations to minimize `loss` by updating `var_list`.
This method simply combines calls `compute_gradients()` and
@@ -685,8 +684,7 @@
stop_gradients: Optional. A Tensor or list of tensors not to differentiate
through.
scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down
- by the number of replicas. By default, auto-detects whether this is
- needed.
+ by the number of replicas. DEPRECATED and generally no longer needed.
Returns:
An Operation that updates the variables in `var_list`. If `global_step`
@@ -732,7 +730,7 @@
aggregation_method=None,
grad_loss=None,
stop_gradients=None,
- scale_loss_by_num_replicas=None):
+ scale_loss_by_num_replicas=False):
"""Compute gradients of `loss` for the variables in `var_list`.
This is the first part of `minimize()`. It returns a list
@@ -756,8 +754,7 @@
stop_gradients: Optional. A Tensor or list of tensors not to differentiate
through.
scale_loss_by_num_replicas: Optional boolean. If true, scale the loss down
- by the number of replicas. By default, auto-detects whether this is
- needed.
+ by the number of replicas. DEPRECATED and generally no longer needed.
Returns:
A list of (gradient, variable) pairs. Variable is always present, but
@@ -781,9 +778,7 @@
tape.watch(var_list)
loss_value = loss()
- # Scale loss for number of replicas (callable-loss case). In this case,
- # we have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
+ # Scale loss for number of replicas (callable-loss case).
loss_value = self._scale_loss(loss_value, scale_loss_by_num_replicas)
if var_list is None:
@@ -839,9 +834,6 @@
@staticmethod
def _scale_loss(loss_value, scale_loss_by_num_replicas):
"""Scale loss for the number of replicas."""
- if scale_loss_by_num_replicas is None:
- scale_loss_by_num_replicas = (
- distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN)
if scale_loss_by_num_replicas:
num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
if num_replicas > 1:
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
index dd7f2f4..2fc0b5e 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2_test.py
@@ -26,7 +26,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
@@ -71,7 +71,7 @@
opt_op = sgd_op.minimize(
cost,
global_step, [var0, var1],
- aggregation_method=gradients_impl.AggregationMethod.
+ aggregation_method=gradients_util.AggregationMethod.
EXPERIMENTAL_ACCUMULATE_N)
variables.global_variables_initializer().run()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
index 17b69c7..c8524e9 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/decode_proto_op_test_base.py
@@ -84,7 +84,10 @@
values = field_dict[field.name]
self.assertEqual(dtypes.as_dtype(values.dtype), field.dtype)
- fd = field.value.DESCRIPTOR.fields_by_name[field.name]
+ if 'ext_value' in field.name:
+ fd = test_example_pb2.PrimitiveValue()
+ else:
+ fd = field.value.DESCRIPTOR.fields_by_name[field.name]
# Values has the same shape as the input plus an extra
# dimension for repeats.
@@ -92,13 +95,16 @@
# Nested messages are represented as TF strings, requiring
# some special handling.
- if field.name == 'message_value':
+ if field.name == 'message_value' or 'ext_value' in field.name:
vs = []
for buf in values.flat:
msg = test_example_pb2.PrimitiveValue()
msg.ParseFromString(buf)
vs.append(msg)
- evs = getattr(field.value, field.name)
+ if 'ext_value' in field.name:
+ evs = field.value.Extensions[test_example_pb2.ext_value]
+ else:
+ evs = getattr(field.value, field.name)
if len(vs) != len(evs):
self.fail('Field %s decoded %d outputs, expected %d' %
(fd.name, len(vs), len(evs)))
@@ -223,7 +229,8 @@
sanitize=False,
force_disordered=True)
- @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ @parameterized.named_parameters(
+ *test_base.ProtoOpTestBase.named_parameters(extension=False))
def testPacked(self, case):
# Now try with the packed serialization.
#
@@ -235,8 +242,7 @@
# Note: float_format='.17g' is necessary to ensure preservation of
# doubles and floats in text format.
text_format.Parse(
- text_format.MessageToString(
- value, float_format='.17g'),
+ text_format.MessageToString(value, float_format='.17g'),
test_example_pb2.PackedTestValue()).SerializeToString()
for value in case.values
]
diff --git a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
index 01b3ccc..5ec681f 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/encode_proto_op_test_base.py
@@ -15,9 +15,6 @@
# =============================================================================
"""Table-driven test for encode_proto op.
-This test is run once with each of the *.TestCase.pbtxt files
-in the test directory.
-
It tests that encode_proto is a lossless inverse of decode_proto
(for the specified fields).
"""
@@ -145,7 +142,8 @@
# loss of packing in the encoding).
self.assertEqual(in_buf, out_buf)
- @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ @parameterized.named_parameters(
+ *test_base.ProtoOpTestBase.named_parameters(extension=False))
def testRoundtrip(self, case):
in_bufs = [value.SerializeToString() for value in case.values]
@@ -154,7 +152,8 @@
return self._testRoundtrip(
in_bufs, 'tensorflow.contrib.proto.TestValue', case.fields)
- @parameterized.named_parameters(*test_base.ProtoOpTestBase.named_parameters())
+ @parameterized.named_parameters(
+ *test_base.ProtoOpTestBase.named_parameters(extension=False))
def testRoundtripPacked(self, case):
# Now try with the packed serialization.
# We test the packed representations by loading the same test cases using
diff --git a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
index 2950c7d..1a63648 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
+++ b/tensorflow/contrib/proto/python/kernel_tests/proto_op_test_base.py
@@ -38,17 +38,18 @@
ct.cdll.LoadLibrary(lib)
@staticmethod
- def named_parameters():
- return (
- ("defaults", ProtoOpTestBase.defaults_test_case()),
- ("minmax", ProtoOpTestBase.minmax_test_case()),
- ("nested", ProtoOpTestBase.nested_test_case()),
- ("optional", ProtoOpTestBase.optional_test_case()),
- ("promote", ProtoOpTestBase.promote_test_case()),
- ("ragged", ProtoOpTestBase.ragged_test_case()),
- ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
- ("simple", ProtoOpTestBase.simple_test_case()),
- )
+ def named_parameters(extension=True):
+ parameters = [("defaults", ProtoOpTestBase.defaults_test_case()),
+ ("minmax", ProtoOpTestBase.minmax_test_case()),
+ ("nested", ProtoOpTestBase.nested_test_case()),
+ ("optional", ProtoOpTestBase.optional_test_case()),
+ ("promote", ProtoOpTestBase.promote_test_case()),
+ ("ragged", ProtoOpTestBase.ragged_test_case()),
+ ("shaped_batch", ProtoOpTestBase.shaped_batch_test_case()),
+ ("simple", ProtoOpTestBase.simple_test_case())]
+ if extension:
+ parameters.append(("extension", ProtoOpTestBase.extension_test_case()))
+ return parameters
@staticmethod
def defaults_test_case():
@@ -400,6 +401,21 @@
return test_case
@staticmethod
+ def extension_test_case():
+ test_case = test_example_pb2.TestCase()
+ value = test_case.values.add()
+ message_value = value.Extensions[test_example_pb2.ext_value].add()
+ message_value.double_value = 23.5
+ test_case.shapes.append(1)
+ test_case.sizes.append(1)
+ field = test_case.fields.add()
+ field.name = test_example_pb2.ext_value.full_name
+ field.dtype = types_pb2.DT_STRING
+ message_value = field.value.Extensions[test_example_pb2.ext_value].add()
+ message_value.double_value = 23.5
+ return test_case
+
+ @staticmethod
def simple_test_case():
test_case = test_example_pb2.TestCase()
value = test_case.values.add()
diff --git a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
index 674d881..b1ce66d 100644
--- a/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
+++ b/tensorflow/contrib/proto/python/kernel_tests/test_example.proto
@@ -61,6 +61,8 @@
optional sfixed64 sfixed64_value_with_default = 32 [default = 11];
optional sint32 sint32_value_with_default = 33 [default = 12];
optional sint64 sint64_value_with_default = 34 [default = 13];
+
+ extensions 100 to 199;
}
// A PackedTestValue looks exactly the same as a TestValue in the text format,
@@ -68,7 +70,7 @@
// by loading the same test cases using this definition instead of TestValue.
//
// NOTE: This definition must be kept in sync with TestValue in every way except
-// the packed=true declaration.
+// the packed=true declaration and the lack of extensions.
message PackedTestValue {
repeated double double_value = 1 [packed = true];
repeated float float_value = 2 [packed = true];
@@ -132,6 +134,10 @@
optional bool bool_value = 1777;
}
+extend TestValue {
+ repeated PrimitiveValue ext_value = 100;
+}
+
// The messages below are for yet-to-be created tests.
message EnumValue {
diff --git a/tensorflow/contrib/quantize/README.md b/tensorflow/contrib/quantize/README.md
index 5b8da92..b335e1a 100644
--- a/tensorflow/contrib/quantize/README.md
+++ b/tensorflow/contrib/quantize/README.md
@@ -8,9 +8,9 @@
For efficient inference, TensorFlow combines batch normalization with the preceding
convolutional and fully-connected layers prior to quantization by
-[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}.
+[folding batch norm layers](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/quantize/python/fold_batch_norms.py){:.external}.
-The quantization error is modeled using [fake quantization](../api_guides/python/array_ops.md#Fake_quantization)
+The quantization error is modeled using [fake quantization](../../api_guides/python/array_ops.md#Fake_quantization)
nodes to simulate the effect of quantization in the forward and backward passes. The
forward-pass models quantization, while the backward-pass models quantization as a
straight-through estimator. Both the forward- and backward-pass simulate the quantization
@@ -105,7 +105,7 @@
--std_value=127.5 --mean_value=127.5
```
-See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../lite/).
+See the documentation for `tf.contrib.quantize` and [TensorFlow Lite](../../lite/).
## Quantized accuracy results
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index e0c6da0..a70f748 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -454,7 +454,7 @@
strides=layer_op.get_attr('strides'),
padding=layer_op.get_attr('padding'),
use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'),
- data_format=layer_op.get_attr('data_format'),
+ data_format=layer_op.get_attr('data_format').decode(),
name=new_layer_name)
elif layer_op.type == 'MatMul':
return math_ops.matmul(
@@ -867,7 +867,7 @@
strides=op.get_attr('strides'),
padding=op.get_attr('padding'),
use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'),
- data_format=op.get_attr('data_format'),
+ data_format=op.get_attr('data_format').decode(),
name=new_name).op
def _CloneDepthwiseConv2d(self, op, inputs, new_name):
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 8619708..39082ca 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -224,8 +224,8 @@
None, default_name=name_prefix, values=[inputs], reuse=reuse) as scope:
scope.set_partitioner(None)
input_shape = inputs.get_shape()
- input_dim = len(input_shape)
if per_channel:
+ input_dim = len(input_shape)
# Only support quantizing 1-, 2- and 4-dimensional tensors.
assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in '
' scope: %s' % (input_shape, name_prefix))
diff --git a/tensorflow/contrib/quantize/python/quant_ops_test.py b/tensorflow/contrib/quantize/python/quant_ops_test.py
index 36d2af9..c636c90 100644
--- a/tensorflow/contrib/quantize/python/quant_ops_test.py
+++ b/tensorflow/contrib/quantize/python/quant_ops_test.py
@@ -63,6 +63,12 @@
self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
self.assertAlmostEqual(max_value, 0.5, delta=1e-3)
+ def testMovingAvgQuantizeTrainingAssignNoShape(self):
+ min_value, max_value = self._GetMinMaxValues(
+ quant_ops.MovingAvgQuantize, [[-1, 1], [0, 0]], shape=None)
+ self.assertAlmostEqual(min_value, -0.5, delta=1e-3)
+ self.assertAlmostEqual(max_value, 0.5, delta=1e-3)
+
def testMovingAvgSymmetricQuantizeTrainingAssign(self):
min_value, max_value = self._GetMinMaxValues(
quant_ops.MovingAvgQuantize, [[-1, 0.5], [0, 0]], symmetric=True)
@@ -109,10 +115,10 @@
is_training=True,
vars_collection=_MIN_MAX_VARS)
- def _GetMinMaxValues(self, quantize_fn, input_values, **kwds):
+ def _GetMinMaxValues(self, quantize_fn, input_values, shape=(2), **kwds):
g = ops.Graph()
with session.Session(graph=g) as sess:
- x = array_ops.placeholder(dtypes.float32, shape=[2])
+ x = array_ops.placeholder(dtypes.float32, shape=shape)
y = quantize_fn(
x,
init_min=0.0,
diff --git a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py
index 2054367..7e79785 100644
--- a/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py
+++ b/tensorflow/contrib/remote_fused_graph/pylib/python/ops/remote_fused_graph_ops.py
@@ -50,13 +50,13 @@
if default_graph_input_tensor_type_shapes:
for type_shape in default_graph_input_tensor_type_shapes:
type_shape_proto = info_proto.default_graph_input_tensor_shape.add()
- type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0]))
+ type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum
for dim in type_shape[1]:
type_shape_proto.shape.dim.add().size = dim
if default_graph_output_tensor_type_shapes:
for type_shape in default_graph_output_tensor_type_shapes:
type_shape_proto = info_proto.default_graph_output_tensor_shape.add()
- type_shape_proto.dtype = int(dtypes.as_dtype(type_shape[0]))
+ type_shape_proto.dtype = dtypes.as_dtype(type_shape[0]).as_datatype_enum
for dim in type_shape[1]:
type_shape_proto.shape.dim.add().size = dim
diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
index 7bad4a6..a70e806 100644
--- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
+++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py
@@ -48,7 +48,7 @@
Linear = core_rnn_cell._Linear # pylint: disable=invalid-name
-class RNNCellTest(test.TestCase):
+class RNNCellTest(test.TestCase, parameterized.TestCase):
def testLinear(self):
with self.cached_session() as sess:
@@ -642,58 +642,54 @@
# The numbers in results were not calculated, this is just a smoke test.
self.assertAllClose(res[0], [[0.154605, 0.154605, 0.154605]])
- def testResidualWrapper(self):
- with self.cached_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 3])
- m = array_ops.zeros([1, 3])
- base_cell = rnn_cell_impl.GRUCell(3)
- g, m_new = base_cell(x, m)
- variable_scope.get_variable_scope().reuse_variables()
- wrapper_object = rnn_cell_impl.ResidualWrapper(base_cell)
- (name, dep), = wrapper_object._checkpoint_dependencies
- wrapper_object.get_config() # Should not throw an error
- self.assertIs(dep, base_cell)
- self.assertEqual("cell", name)
+ @parameterized.parameters(
+ [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testResidualWrapper(self, wrapper_type):
+ x = ops.convert_to_tensor(np.array([[1., 1., 1.]]))
+ m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
+ base_cell = rnn_cell_impl.GRUCell(
+ 3, kernel_initializer=init_ops.constant_initializer(0.5),
+ bias_initializer=init_ops.constant_initializer(0.5))
+ g, m_new = base_cell(x, m)
+ wrapper_object = wrapper_type(base_cell)
+ (name, dep), = wrapper_object._checkpoint_dependencies
+ wrapper_object.get_config() # Should not throw an error
+ self.assertIs(dep, base_cell)
+ self.assertEqual("cell", name)
- g_res, m_new_res = wrapper_object(x, m)
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([g, g_res, m_new, m_new_res], {
- x: np.array([[1., 1., 1.]]),
- m: np.array([[0.1, 0.1, 0.1]])
- })
- # Residual connections
- self.assertAllClose(res[1], res[0] + [1., 1., 1.])
- # States are left untouched
- self.assertAllClose(res[2], res[3])
+ g_res, m_new_res = wrapper_object(x, m)
+ self.evaluate([variables_lib.global_variables_initializer()])
+ res = self.evaluate([g, g_res, m_new, m_new_res])
+ # Residual connections
+ self.assertAllClose(res[1], res[0] + [1., 1., 1.])
+ # States are left untouched
+ self.assertAllClose(res[2], res[3])
- def testResidualWrapperWithSlice(self):
- with self.cached_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- x = array_ops.zeros([1, 5])
- m = array_ops.zeros([1, 3])
- base_cell = rnn_cell_impl.GRUCell(3)
- g, m_new = base_cell(x, m)
- variable_scope.get_variable_scope().reuse_variables()
+ @parameterized.parameters(
+ [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testResidualWrapperWithSlice(self, wrapper_type):
+ x = ops.convert_to_tensor(np.array([[1., 1., 1., 1., 1.]]))
+ m = ops.convert_to_tensor(np.array([[0.1, 0.1, 0.1]]))
+ base_cell = rnn_cell_impl.GRUCell(
+ 3, kernel_initializer=init_ops.constant_initializer(0.5),
+ bias_initializer=init_ops.constant_initializer(0.5))
+ g, m_new = base_cell(x, m)
- def residual_with_slice_fn(inp, out):
- inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
- return inp_sliced + out
+ def residual_with_slice_fn(inp, out):
+ inp_sliced = array_ops.slice(inp, [0, 0], [-1, 3])
+ return inp_sliced + out
- g_res, m_new_res = rnn_cell_impl.ResidualWrapper(
- base_cell, residual_with_slice_fn)(x, m)
- sess.run([variables_lib.global_variables_initializer()])
- res_g, res_g_res, res_m_new, res_m_new_res = sess.run(
- [g, g_res, m_new, m_new_res], {
- x: np.array([[1., 1., 1., 1., 1.]]),
- m: np.array([[0.1, 0.1, 0.1]])
- })
- # Residual connections
- self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
- # States are left untouched
- self.assertAllClose(res_m_new, res_m_new_res)
+ g_res, m_new_res = wrapper_type(
+ base_cell, residual_with_slice_fn)(x, m)
+ self.evaluate([variables_lib.global_variables_initializer()])
+ res_g, res_g_res, res_m_new, res_m_new_res = self.evaluate(
+ [g, g_res, m_new, m_new_res])
+ # Residual connections
+ self.assertAllClose(res_g_res, res_g + [1., 1., 1.])
+ # States are left untouched
+ self.assertAllClose(res_m_new, res_m_new_res)
def testDeviceWrapper(self):
with variable_scope.variable_scope(
@@ -836,7 +832,98 @@
self.assertAllClose(res[0], [[0.175991, 0.175991]])
self.assertAllClose(res[1], [[0.13248, 0.13248]])
+ @parameterized.parameters(
+ [[rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2],
+ [rnn_cell_impl.ResidualWrapper, rnn_cell_impl.ResidualWrapperV2]])
+ @test_util.run_in_graph_and_eager_modes
+ def testWrapperKerasStyle(self, wrapper, wrapper_v2):
+ """Tests if wrapper cell is instantiated in keras style scope."""
+ wrapped_cell_v2 = wrapper_v2(rnn_cell_impl.BasicRNNCell(1))
+ self.assertTrue(wrapped_cell_v2._keras_style)
+ wrapped_cell = wrapper(rnn_cell_impl.BasicRNNCell(1))
+ self.assertFalse(wrapped_cell._keras_style)
+
+ @parameterized.parameters(
+ [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testWrapperV2VariableNames(self, wrapper):
+ """Tests that variables names do not depend on wrapper in RNN layer."""
+
+ def _rnn_input(apply_wrapper, name):
+ """Creates a RNN layer with/without wrapper and returns built rnn cell."""
+ with base_layer.keras_style_scope():
+ base_cell = rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell")
+ for _ in range(2)])
+ if apply_wrapper:
+ rnn_cell = wrapper(base_cell)
+ else:
+ rnn_cell = base_cell
+ rnn_layer = keras_layers.RNN(rnn_cell, name=name)
+ inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
+ _ = rnn_layer(inputs)
+ return base_cell._cells[0]
+
+ rnn_1 = _rnn_input(True, name="rnn_0")
+ rnn_2 = _rnn_input(False, name="rnn_1")
+
+ for i, cell in enumerate([rnn_1, rnn_2]):
+ var_prefix = "rnn_{}/cell_0/basic_rnn_cell/".format(i)
+ self.assertCountEqual([v.name for v in cell.weights],
+ (var_prefix + "kernel:0", var_prefix + "bias:0"))
+
+ @parameterized.parameters(
+ [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testWrapperWeights(self, wrapper):
+ """Tests that wrapper weights contain wrapped cells weights."""
+
+ with base_layer.keras_style_scope():
+ base_cell = rnn_cell_impl.BasicRNNCell(1, name="basic_rnn_cell")
+ rnn_cell = wrapper(base_cell)
+ rnn_layer = keras_layers.RNN(rnn_cell)
+ inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
+ rnn_layer(inputs)
+
+ expected_weights = ["rnn/" + var for var in ("kernel:0", "bias:0")]
+ self.assertEqual(len(rnn_cell.weights), 2)
+ self.assertCountEqual([v.name for v in rnn_cell.weights], expected_weights)
+ self.assertCountEqual([v.name for v in rnn_cell.trainable_variables],
+ expected_weights)
+ self.assertCountEqual([v.name for v in rnn_cell.non_trainable_variables],
+ [])
+ self.assertCountEqual([v.name for v in rnn_cell._cell.weights],
+ expected_weights)
+
+ @parameterized.parameters(
+ [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testWrapperV2Caller(self, wrapper):
+ """Tests that wrapper V2 is using the LayerRNNCell's caller."""
+
+ with base_layer.keras_style_scope():
+ base_cell = rnn_cell_impl.MultiRNNCell(
+ [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
+ rnn_cell = wrapper(base_cell)
+ inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
+ state = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
+ _ = rnn_cell(inputs, [state, state])
+ weights = base_cell._cells[0].weights
+ self.assertLen(weights, expected_len=2)
+ self.assertTrue(all(["_wrapper" in v.name for v in weights]))
+
+ @parameterized.parameters(
+ [rnn_cell_impl.DropoutWrapperV2, rnn_cell_impl.ResidualWrapperV2])
+ @test_util.run_in_graph_and_eager_modes
+ def testWrapperV2Build(self, wrapper):
+ cell = rnn_cell_impl.LSTMCell(10)
+ wrapper = wrapper(cell)
+ wrapper.build((1,))
+ self.assertTrue(cell.built)
+
+
+@test_util.run_all_in_graph_and_eager_modes
class DropoutWrapperTest(test.TestCase, parameterized.TestCase):
def _testDropoutWrapper(self,
@@ -844,39 +931,38 @@
time_steps=None,
parallel_iterations=None,
wrapper_type=None,
+ scope="root",
**kwargs):
- with self.cached_session() as sess:
- with variable_scope.variable_scope(
- "root", initializer=init_ops.constant_initializer(0.5)):
- if batch_size is None and time_steps is None:
- # 2 time steps, batch size 1, depth 3
- batch_size = 1
- time_steps = 2
- x = constant_op.constant(
- [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
- m = rnn_cell_impl.LSTMStateTuple(
- *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32
- )] * 2)
- else:
- x = constant_op.constant(
- np.random.randn(time_steps, batch_size, 3).astype(np.float32))
- m = rnn_cell_impl.LSTMStateTuple(*[
- constant_op.
- constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)
- ] * 2)
- outputs, final_state = rnn.dynamic_rnn(
- cell=wrapper_type(
- rnn_cell_impl.LSTMCell(3), dtype=x.dtype, **kwargs),
- time_major=True,
- parallel_iterations=parallel_iterations,
- inputs=x,
- initial_state=m)
- sess.run([variables_lib.global_variables_initializer()])
- res = sess.run([outputs, final_state])
- self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
- self.assertEqual(res[1].c.shape, (batch_size, 3))
- self.assertEqual(res[1].h.shape, (batch_size, 3))
- return res
+ if batch_size is None and time_steps is None:
+ # 2 time steps, batch size 1, depth 3
+ batch_size = 1
+ time_steps = 2
+ x = constant_op.constant(
+ [[[2., 2., 2.]], [[1., 1., 1.]]], dtype=dtypes.float32)
+ m = rnn_cell_impl.LSTMStateTuple(
+ *[constant_op.constant([[0.1, 0.1, 0.1]], dtype=dtypes.float32)] * 2)
+ else:
+ x = constant_op.constant(
+ np.random.randn(time_steps, batch_size, 3).astype(np.float32))
+ m = rnn_cell_impl.LSTMStateTuple(*[
+ constant_op.
+ constant([[0.1, 0.1, 0.1]] * batch_size, dtype=dtypes.float32)] * 2)
+ outputs, final_state = rnn.dynamic_rnn(
+ cell=wrapper_type(
+ rnn_cell_impl.LSTMCell(
+ 3, initializer=init_ops.constant_initializer(0.5)),
+ dtype=x.dtype, **kwargs),
+ time_major=True,
+ parallel_iterations=parallel_iterations,
+ inputs=x,
+ initial_state=m,
+ scope=scope)
+ self.evaluate([variables_lib.global_variables_initializer()])
+ res = self.evaluate([outputs, final_state])
+ self.assertEqual(res[0].shape, (time_steps, batch_size, 3))
+ self.assertEqual(res[1].c.shape, (batch_size, 3))
+ self.assertEqual(res[1].h.shape, (batch_size, 3))
+ return res
@parameterized.parameters(
[rnn_cell_impl.DropoutWrapper, rnn_cell_impl.DropoutWrapperV2])
@@ -946,10 +1032,8 @@
state_keep_prob=keep_some,
seed=10,
parallel_iterations=1,
- wrapper_type=wrapper_type)
- # Clear away the graph and the test session (which keeps variables around)
- ops.reset_default_graph()
- self._ClearCachedSession()
+ wrapper_type=wrapper_type,
+ scope="root_1")
random_seed.set_random_seed(2)
res_standard_2 = self._testDropoutWrapper(
input_keep_prob=keep_some,
@@ -957,7 +1041,8 @@
state_keep_prob=keep_some,
seed=10,
parallel_iterations=1,
- wrapper_type=wrapper_type)
+ wrapper_type=wrapper_type,
+ scope="root_2")
self.assertAllClose(res_standard_1[0], res_standard_2[0])
self.assertAllClose(res_standard_1[1].c, res_standard_2[1].c)
self.assertAllClose(res_standard_1[1].h, res_standard_2[1].h)
@@ -1091,9 +1176,8 @@
input_size=3,
batch_size=5,
time_steps=7,
- seed=-234987)
- ops.reset_default_graph()
- self._ClearCachedSession()
+ seed=-234987,
+ scope="root_0")
random_seed.set_random_seed(2347)
np.random.seed(23487)
res1 = self._testDropoutWrapper(
@@ -1105,7 +1189,8 @@
input_size=3,
batch_size=5,
time_steps=7,
- seed=-234987)
+ seed=-234987,
+ scope="root_1")
output_mask = np.abs(res0[0]) > 1e-6
for time_step in output_mask:
@@ -1128,60 +1213,6 @@
self.assertAllClose(res0[1].c, res1[1].c)
self.assertAllClose(res0[1].h, res1[1].h)
- def testDropoutWrapperKerasStyle(self):
- """Tests if DropoutWrapperV2 cell is instantiated in keras style scope."""
- wrapped_cell_v2 = rnn_cell_impl.DropoutWrapperV2(
- rnn_cell_impl.BasicRNNCell(1))
- self.assertTrue(wrapped_cell_v2._keras_style)
-
- wrapped_cell = rnn_cell_impl.DropoutWrapper(rnn_cell_impl.BasicRNNCell(1))
- self.assertFalse(wrapped_cell._keras_style)
-
- def testDropoutWrapperV2VariableNames(self):
- """Tests that variables names do not depend on wrapper in RNN layer."""
-
- def _rnn_input(apply_wrapper):
- """Creates a RNN layer with/without wrapper and returns built rnn cell."""
- with base_layer.keras_style_scope():
- base_cell = rnn_cell_impl.MultiRNNCell(
- [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
- if apply_wrapper:
- rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell)
- else:
- rnn_cell = base_cell
- rnn_layer = keras_layers.RNN(rnn_cell)
- inputs = ops.convert_to_tensor([[[1]]], dtype=dtypes.float32)
- _ = rnn_layer(inputs)
- return base_cell._cells[0]
-
- rnn_1 = _rnn_input(True)
- ops.reset_default_graph()
- rnn_2 = _rnn_input(False)
-
- self.assertLen(rnn_1.weights, expected_len=2)
- self.assertCountEqual([v.name for v in rnn_1.weights],
- [v.name for v in rnn_2.weights])
-
- def testDropoutWrapperV2Caller(self):
- """Tests that DropoutWrapperV2 is using the LayerRNNCell's caller."""
-
- with base_layer.keras_style_scope():
- base_cell = rnn_cell_impl.MultiRNNCell(
- [rnn_cell_impl.BasicRNNCell(1) for _ in range(2)])
- rnn_cell = rnn_cell_impl.DropoutWrapperV2(base_cell)
- inputs = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
- state = ops.convert_to_tensor([[1]], dtype=dtypes.float32)
- _ = rnn_cell(inputs, [state, state])
- weights = base_cell._cells[0].weights
- self.assertLen(weights, expected_len=2)
- self.assertTrue(all(["dropout_wrapper" in v.name for v in weights]))
-
- def testDropoutWrapperV2Build(self):
- cell = rnn_cell_impl.LSTMCell(10)
- wrapper = rnn_cell_impl.DropoutWrapperV2(cell)
- wrapper.build((1,))
- self.assertTrue(cell.built)
-
def basic_rnn_cell(inputs, state, num_units, scope=None):
if state is None:
diff --git a/tensorflow/contrib/rnn/python/ops/rnn.py b/tensorflow/contrib/rnn/python/ops/rnn.py
index 0266b72..41b1698 100644
--- a/tensorflow/contrib/rnn/python/ops/rnn.py
+++ b/tensorflow/contrib/rnn/python/ops/rnn.py
@@ -131,7 +131,8 @@
sequence_length=None,
parallel_iterations=None,
time_major=False,
- scope=None):
+ scope=None,
+ swap_memory=False):
"""Creates a dynamic bidirectional recurrent neural network.
Stacks several bidirectional rnn layers. The combined forward and backward
@@ -171,6 +172,10 @@
data is batch-major, so by default this function accepts input and emits
output in batch-major form.
scope: VariableScope for the created subgraph; defaults to None.
+ swap_memory: Transparently swap the tensors produced in forward inference
+ but needed for back prop from GPU to CPU. This allows training RNNs
+ which would typically not fit on a single GPU, with very minimal (or no)
+ performance penalty.
Returns:
A tuple (outputs, output_state_fw, output_state_bw) where:
@@ -230,6 +235,7 @@
sequence_length=sequence_length,
parallel_iterations=parallel_iterations,
dtype=dtype,
+ swap_memory=swap_memory,
time_major=time_major)
# Concat the outputs to create the new input.
prev_layer = array_ops.concat(outputs, 2)
diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD
index 8bbdf96..e9595d1 100644
--- a/tensorflow/contrib/slim/python/slim/nets/BUILD
+++ b/tensorflow/contrib/slim/python/slim/nets/BUILD
@@ -115,9 +115,9 @@
py_test(
name = "inception_v1_test",
- size = "large",
+ size = "medium",
srcs = ["inception_v1_test.py"],
- shard_count = 3,
+ shard_count = 8,
srcs_version = "PY2AND3",
deps = [
":inception_v1",
@@ -135,9 +135,9 @@
py_test(
name = "inception_v2_test",
- size = "large",
+ size = "medium",
srcs = ["inception_v2_test.py"],
- shard_count = 3,
+ shard_count = 8,
srcs_version = "PY2AND3",
deps = [
":inception_v2",
@@ -155,9 +155,9 @@
py_test(
name = "inception_v3_test",
- size = "large",
+ size = "medium",
srcs = ["inception_v3_test.py"],
- shard_count = 3,
+ shard_count = 8,
srcs_version = "PY2AND3",
deps = [
":inception_v3",
@@ -233,8 +233,9 @@
py_test(
name = "resnet_v1_test",
- size = "large",
+ size = "medium",
srcs = ["resnet_v1_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
deps = [
":resnet_utils",
@@ -268,8 +269,9 @@
py_test(
name = "resnet_v2_test",
- size = "large",
+ size = "medium",
srcs = ["resnet_v2_test.py"],
+ shard_count = 4,
srcs_version = "PY2AND3",
deps = [
":resnet_utils",
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index e13edd1..91b6d26 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -11,22 +11,14 @@
load(
"//tensorflow:tensorflow.bzl",
- "tf_copts",
"tf_cuda_library",
"tf_custom_op_library_additional_deps",
)
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
-load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load(
"@local_config_tensorrt//:build_defs.bzl",
"if_tensorrt",
)
-exports_files(glob([
- "test/testdata/*",
-]))
-
tf_cuda_library(
name = "trt_shape_function",
srcs = ["shape_fn/trt_shfn.cc"],
@@ -45,151 +37,11 @@
srcs = [
"__init__.py",
"python/__init__.py",
+ "python/trt_convert.py",
],
srcs_version = "PY2AND3",
deps = [
- ":tf_trt_integration_test_base",
- ":trt_convert_py",
- ":trt_ops_py",
- "//tensorflow/python:errors",
- ],
-)
-
-py_library(
- name = "trt_ops_py",
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/compiler/tf2tensorrt:trt_ops",
- "//tensorflow/compiler/tf2tensorrt:trt_ops_loader",
- ],
-)
-
-py_library(
- name = "trt_convert_py",
- srcs = ["python/trt_convert.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":wrap_conversion",
- "//tensorflow/python:graph_util",
- "//tensorflow/python:session",
- "//tensorflow/python:tf_optimizer",
- "//tensorflow/python/saved_model:builder",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:tag_constants",
- ],
-)
-
-# TODO(aaroey): this wrapper has been causing troubles of double linking, so
-# either get rid of it, or split to make it contain minimum dependencies.
-tf_py_wrap_cc(
- name = "wrap_conversion",
- srcs = ["trt_conversion.i"],
- copts = tf_copts(),
- swig_includes = [
- "//tensorflow/python:platform/base.i",
- ],
- deps = [
- "//tensorflow/compiler/tf2tensorrt:test_utils",
- "//tensorflow/compiler/tf2tensorrt:trt_conversion",
- "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
- "//third_party/python_runtime:headers",
- ],
-)
-
-py_library(
- name = "tf_trt_integration_test_base",
- srcs = ["test/tf_trt_integration_test_base.py"],
- deps = [
- ":trt_convert_py",
- ":trt_ops_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- ],
-)
-
-cuda_py_test(
- name = "trt_convert_test",
- srcs = ["python/trt_convert_test.py"],
- additional_deps = [
- ":trt_convert_py",
- ":trt_ops_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:graph_util",
- "//tensorflow/python/saved_model:builder",
- "//tensorflow/python/saved_model:loader",
- "//tensorflow/python/saved_model:signature_constants",
- "//tensorflow/python/saved_model:signature_def_utils",
- "//tensorflow/python/saved_model:tag_constants",
- "//tensorflow/python/saved_model:utils",
- "//tensorflow/python/tools:freeze_graph_lib",
- "//tensorflow/python/tools:saved_model_utils",
- ],
- tags = [
- "no_cuda_on_cpu_tap",
- "no_windows",
- "nomac",
- ],
-)
-
-cuda_py_tests(
- name = "tf_trt_integration_test",
- srcs = [
- "test/base_test.py",
- "test/batch_matmul_test.py",
- "test/biasadd_matmul_test.py",
- "test/binary_tensor_weight_broadcast_test.py",
- "test/concatenation_test.py",
- "test/const_broadcast_test.py",
- "test/conv2d_test.py",
- "test/dynamic_input_shapes_test.py",
- "test/identity_output_test.py",
- "test/int32_test.py",
- "test/lru_cache_test.py",
- "test/memory_alignment_test.py",
- "test/multi_connection_neighbor_engine_test.py",
- "test/neighboring_engine_test.py",
- "test/quantization_test.py",
- "test/rank_two_test.py",
- "test/reshape_transpose_test.py",
- "test/topk_test.py",
- "test/unary_test.py",
- "test/vgg_block_nchw_test.py",
- "test/vgg_block_test.py",
- ],
- additional_deps = [
- ":tf_trt_integration_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- ],
- tags = [
- "no_cuda_on_cpu_tap",
- "no_windows",
- "nomac",
- ],
-)
-
-cuda_py_test(
- name = "quantization_mnist_test",
- srcs = ["test/quantization_mnist_test.py"],
- additional_deps = [
- ":tf_trt_integration_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python/keras:keras",
- "//tensorflow/python/estimator:estimator",
- ],
- data = [
- "test/testdata/checkpoint",
- "test/testdata/model.ckpt-46900.data-00000-of-00001",
- "test/testdata/model.ckpt-46900.index",
- ],
- tags = [
- "no_cuda_on_cpu_tap",
- "no_pip",
- "no_tap", # It is not able to download the mnist data.
- "no_windows",
- "nomac",
+ "//tensorflow/python/compiler/tensorrt:init_py",
],
)
diff --git a/tensorflow/contrib/tensorrt/__init__.py b/tensorflow/contrib/tensorrt/__init__.py
index 140ad48..fd551d70 100644
--- a/tensorflow/contrib/tensorrt/__init__.py
+++ b/tensorflow/contrib/tensorrt/__init__.py
@@ -18,18 +18,6 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import errors
-
-# pylint: disable=unused-import,wildcard-import,g-import-not-at-top
-try:
- from tensorflow.contrib.tensorrt.python import *
-except errors.NotFoundError as e:
- no_trt_message = (
- '**** Failed to initialize TensorRT. This is either because the TensorRT'
- ' installation path is not in LD_LIBRARY_PATH, or because you do not have'
- ' it installed. If not installed, please go to'
- ' https://developer.nvidia.com/tensorrt to download and install'
- ' TensorRT ****')
- print(no_trt_message)
- raise e
-# pylint: enable=unused-import,wildcard-import,g-import-not-at-top
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.tensorrt.python import *
+# pylint: enable=unused-import,wildcard-import
diff --git a/tensorflow/contrib/tensorrt/python/__init__.py b/tensorflow/contrib/tensorrt/python/__init__.py
index 75490ae..0cae401 100644
--- a/tensorflow/contrib/tensorrt/python/__init__.py
+++ b/tensorflow/contrib/tensorrt/python/__init__.py
@@ -19,12 +19,6 @@
from __future__ import print_function
# pylint: disable=unused-import,line-too-long
-from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
-from tensorflow.contrib.tensorrt.python.trt_convert import add_test_value
from tensorflow.contrib.tensorrt.python.trt_convert import calib_graph_to_infer_graph
-from tensorflow.contrib.tensorrt.python.trt_convert import clear_test_values
from tensorflow.contrib.tensorrt.python.trt_convert import create_inference_graph
-from tensorflow.contrib.tensorrt.python.trt_convert import enable_test_value
-from tensorflow.contrib.tensorrt.python.trt_convert import get_test_value
-from tensorflow.contrib.tensorrt.python.trt_convert import is_tensorrt_enabled
# pylint: enable=unused-import,line-too-long
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert.py b/tensorflow/contrib/tensorrt/python/trt_convert.py
index 49d7223..4a95937 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert.py
+++ b/tensorflow/contrib/tensorrt/python/trt_convert.py
@@ -18,411 +18,41 @@
from __future__ import division
from __future__ import print_function
-import six as _six
-# pylint: disable=unused-import,line-too-long
-from tensorflow.contrib.tensorrt.wrap_conversion import add_test_value
-from tensorflow.contrib.tensorrt.wrap_conversion import calib_convert
-from tensorflow.contrib.tensorrt.wrap_conversion import clear_test_values
-from tensorflow.contrib.tensorrt.wrap_conversion import enable_test_value
-from tensorflow.contrib.tensorrt.wrap_conversion import get_linked_tensorrt_version
-from tensorflow.contrib.tensorrt.wrap_conversion import get_loaded_tensorrt_version
-from tensorflow.contrib.tensorrt.wrap_conversion import get_test_value
-from tensorflow.contrib.tensorrt.wrap_conversion import is_tensorrt_enabled
-# pylint: enable=unused-import,line-too-long
-from tensorflow.core.framework import graph_pb2
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.core.protobuf import meta_graph_pb2
-from tensorflow.core.protobuf import rewriter_config_pb2
-from tensorflow.python.client import session
-from tensorflow.python.framework import errors_impl as _impl
-from tensorflow.python.framework import graph_util
-from tensorflow.python.framework import importer
-from tensorflow.python.framework import ops
-from tensorflow.python.grappler import tf_optimizer
-from tensorflow.python.platform import tf_logging
-from tensorflow.python.saved_model import builder
-from tensorflow.python.saved_model import loader_impl
-from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.training import saver
+from tensorflow.python.compiler.tensorrt import trt_convert
-def _to_bytes(s):
- """Encode s if it is a sequence of chars."""
- if isinstance(s, _six.text_type):
- return s.encode("utf-8", errors="surrogateescape")
- return s
-
-
-def _to_string(s):
- """Decode s if it is a sequence of bytes."""
- if isinstance(s, _six.binary_type):
- return s.decode("utf-8")
- return s
-
-
-class TrtPrecisionMode(object):
- FP32 = "FP32"
- FP16 = "FP16"
- INT8 = "INT8"
-
- @staticmethod
- def supported_precision_modes():
- return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
-
-
-def get_tensorrt_rewriter_config(rewriter_config=None,
- max_batch_size=1,
- max_workspace_size_bytes=2 << 20,
- precision_mode=TrtPrecisionMode.FP32,
- minimum_segment_size=3,
- is_dynamic_op=False,
- maximum_cached_engines=1,
- cached_engine_batches=None,
- use_calibration=True):
- """Returns a RewriterConfig proto for TRT transformation.
-
- Args:
- rewriter_config: a template RewriterConfig proto used to create a
- TRT-enabled RewriterConfig. If None, it will use a default one.
- max_batch_size: max size for the input batch
- max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
- engine can use at execution time. This corresponds to the 'workspaceSize'
- parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
- precision_mode: one of TrtPrecisionMode.supported_precision_modes().
- minimum_segment_size: the minimum number of nodes required for a subgraph to
- be replaced by TRTEngineOp.
- is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
- network and engine at run time.
- maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- If the number of cached engines is already at max but none of them can
- serve the input, the TRTEngineOp will fall back to run the TF function
- based on which the TRTEngineOp is created.
- cached_engine_batches: a list of batch sizes used to create cached
- engines, only used when is_dynamic_op is True. The length of the list
- should be <= maximum_cached_engines, and the dynamic TRT op will
- use this list to determine the batch sizes of the cached engines, instead
- of making the decision on the fly. This is useful when we know the most
- common batch size(s) the application is going to generate.
- use_calibration: this argument is ignored if precision_mode is not INT8. If
- set to True, a calibration graph will be created to calibrate the missing
- ranges. The calibration graph must be converted to an inference graph
- using calib_graph_to_infer_graph() after running calibration. if set to
- False, quantization nodes will be expected for every tensor in the graph
- (exlcuding those which will be fused). If a range is missing, an error
- will occur. Please note that accuracy may be negatively affected if there
- is a mismatch between which tensors TRT quantizes and which tensors were
- trained with fake quantization.
-
- Returns:
- A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
-
- Raises:
- TypeError: if any of the parameters are of unexpected type.
- ValueError: if any of the parameters are of unexpected value.
- """
- if rewriter_config is not None and not isinstance(
- rewriter_config, rewriter_config_pb2.RewriterConfig):
- raise TypeError("rewriter_config should be a RewriterConfig proto.")
-
- rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
- if rewriter_config is None:
- # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
- # need to run constant folding again.
- rewriter_config_with_trt.optimizers.extend(
- ["constfold", "layout", "constfold"])
- rewriter_config_with_trt.meta_optimizer_iterations = (
- rewriter_config_pb2.RewriterConfig.ONE)
- else:
- rewriter_config_with_trt.CopyFrom(rewriter_config)
-
- if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes():
- raise ValueError(("precision mode '{}' is not supported."
- "It should be one of {}").format(
- precision_mode,
- TrtPrecisionMode.supported_precision_modes))
-
- optimizer = rewriter_config_with_trt.custom_optimizers.add()
- optimizer.name = "TensorRTOptimizer"
- optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
- optimizer.parameter_map["max_batch_size"].i = max_batch_size
- optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
- optimizer.parameter_map[
- "max_workspace_size_bytes"].i = max_workspace_size_bytes
- optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
- optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
- if cached_engine_batches:
- if not isinstance(cached_engine_batches, list):
- raise TypeError("cached_engine_batches should be a list.")
- if len(cached_engine_batches) > maximum_cached_engines:
- raise ValueError("cached_engine_batches should not contain more than "
- "maximum_cached_engines items.")
- optimizer.parameter_map["cached_engine_batches"].list.i.extend(
- cached_engine_batches)
- optimizer.parameter_map["use_calibration"].b = use_calibration
- return rewriter_config_with_trt
-
-
-def create_inference_graph(input_graph_def,
- outputs,
- max_batch_size=1,
- max_workspace_size_bytes=2 << 20,
- precision_mode=TrtPrecisionMode.FP32,
- minimum_segment_size=3,
- is_dynamic_op=False,
- maximum_cached_engines=1,
- cached_engine_batches=None,
- use_calibration=True,
- input_saved_model_dir=None,
- input_saved_model_tags=None,
- output_saved_model_dir=None,
- session_config=None):
- """Python wrapper for the TRT transformation.
-
- Args:
- input_graph_def: a GraphDef object containing a model to be transformed. If
- set to None, the graph will be read from the SavedModel loaded from
- input_saved_model_dir.
- outputs: list of tensors or node names for the model outputs. Only used when
- input_graph_def is not None.
- max_batch_size: max size for the input batch.
- max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
- engine can use at execution time. This corresponds to the 'workspaceSize'
- parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
- precision_mode: one of TrtPrecisionMode.supported_precision_modes().
- minimum_segment_size: the minimum number of nodes required for a subgraph to
- be replaced by TRTEngineOp.
- is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
- network and engine at run time.
- maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
- If the number of cached engines is already at max but none of them can
- serve the input, the TRTEngineOp will fall back to run the TF function
- based on which the TRTEngineOp is created.
- cached_engine_batches: a list of batch sizes used to create cached
- engines, only used when is_dynamic_op is True. The length of the list
- should be <= maximum_cached_engines, and the dynamic TRT op will
- use this list to determine the batch sizes of the cached engines, instead
- of making the decision on the fly. This is useful when we know the most
- common batch size(s) the application is going to generate.
- use_calibration: this argument is ignored if precision_mode is not INT8. If
- set to True, a calibration graph will be created to calibrate the missing
- ranges. The calibration graph must be converted to an inference graph
- using calib_graph_to_infer_graph() after running calibration. if set to
- False, quantization nodes will be expected for every tensor in the graph
- (exlcuding those which will be fused). If a range is missing, an error
- will occur. Please note that accuracy may be negatively affected if there
- is a mismatch between which tensors TRT quantizes and which tensors were
- trained with fake quantization.
- input_saved_model_dir: the directory to load the SavedModel which contains
- the input graph to transforms. Used only when input_graph_def is None.
- input_saved_model_tags: list of tags to load the SavedModel.
- output_saved_model_dir: if not None, construct a SavedModel using the
- returned GraphDef and save it to the specified directory. This option only
- works when the input graph is loaded from a SavedModel, i.e. when
- input_saved_model_dir is specified and input_graph_def is None.
- session_config: the ConfigProto used to create a Session. It's also used as
- a template to create a TRT-enabled ConfigProto for conversion. If not
- specified, a default ConfigProto will be used.
-
- Returns:
- A GraphDef transformed from input_graph_def (or the SavedModel graph def
- loaded from input_saved_model_dir, if input_graph_def is not present), where
- all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
- function is added for each of the subgraphs.
-
- If is_dynamic_op is True, each TRTEngineOp will contain a serialized
- subgraph GraphDef, which will be converted to a TRT engine at execution time
- and the TRT engine will be cached for future usage. A new TRT engine will be
- created each time when none of the cached engines match the input shapes. If
- it fails to execute the TRT engine or the number of cached engines reaches
- maximum_cached_engines, the op will fall back to call the corresponding TF
- function.
-
- If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
- engine created from the corresponding subgraph. No more engines will be
- created on the fly, and the op will fall back to call the corresponding TF
- function when it fails to execute the engine.
-
- Raises:
- ValueError: if the combination of the parameters is invalid.
- RuntimeError: if the TensorRT library version is incompatible.
- """
- compiled_version = get_linked_tensorrt_version()
- loaded_version = get_loaded_tensorrt_version()
- version_mismatch = False
- if loaded_version[0] < compiled_version[0]:
- tf_logging.error(
- "TensorRT version mismatch. Tensorflow was compiled against " +
- "TensorRT %s but library loaded from environment is TensorRT %s" %
- (".".join([str(x) for x in compiled_version]),
- ".".join([str(x) for x in loaded_version])) +
- ". Please make sure that correct version of TensorRT " +
- "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
- raise RuntimeError("Incompatible TensorRT library version")
- for i in zip(loaded_version, compiled_version):
- if i[0] != i[1]:
- tf_logging.warn("TensorRT mismatch. Compiled against version " +
- "%s, but loaded %s. Things may not work" %
- (".".join([str(x) for x in compiled_version]),
- ".".join([str(x) for x in loaded_version])))
- version_mismatch = True
- break
- if not version_mismatch:
- tf_logging.info("Running against TensorRT version %s" % ".".join(
- [str(x) for x in loaded_version]))
-
- if session_config is None:
- session_config = config_pb2.ConfigProto()
-
- if input_saved_model_tags is None:
- input_saved_model_tags = [tag_constants.SERVING]
- saved_model_loader = None
- grappler_meta_graph_def = None
-
- if input_graph_def is None:
- # Read from SavedModel and freeze the graph if necessary.
- if input_saved_model_dir is None:
- raise ValueError("input_graph_def and input_saved_model_dir cannot be "
- "both None")
- with ops.Graph().as_default():
- with session.Session(config=session_config) as sess:
- saved_model_loader = loader_impl.SavedModelLoader(input_saved_model_dir)
- input_meta_graph_def = saved_model_loader.load(sess,
- input_saved_model_tags)
- output_node_names = set()
-
- def _gather_names(tensor_info):
- """Get the node names from a TensorInfo."""
- return set(
- [tensor_info[key].name.split(":")[0] for key in tensor_info])
-
- # Get input and outputs from all SignatureDef.
- for key in input_meta_graph_def.signature_def:
- signature_def = input_meta_graph_def.signature_def[key]
- output_node_names.update(_gather_names(signature_def.inputs))
- output_node_names.update(_gather_names(signature_def.outputs))
-
- # Freeze the variables in the SavedModel graph and copy the frozen
- # graph over.
- frozen_graph_def = graph_util.convert_variables_to_constants(
- sess, sess.graph.as_graph_def(add_shapes=True),
- list(output_node_names))
- grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
- grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
-
- # Copy the collections that are not variables.
- for key in input_meta_graph_def.collection_def:
- # TODO(laigd): currently we use the collection key to filter out
- # collections that depend on variable ops, but this may miss some
- # other user-defined collections. A better way would be to use
- # CollectionDef::NodeList for the filtering.
- if key not in [
- "variables", "local_variables", "model_variables",
- "trainable_variables", "train_op", "table_initializer"
- ]:
- grappler_meta_graph_def.collection_def[key].CopyFrom(
- input_meta_graph_def.collection_def[key])
-
- # Copy other information.
- grappler_meta_graph_def.meta_info_def.CopyFrom(
- input_meta_graph_def.meta_info_def)
- for key in input_meta_graph_def.signature_def:
- grappler_meta_graph_def.signature_def[key].CopyFrom(
- input_meta_graph_def.signature_def[key])
- # TODO(laigd): maybe add back AssetFileDef.
- else:
- if output_saved_model_dir is not None:
- raise ValueError("output_saved_model_dir cannot be set when "
- "input_graph_def is set")
- # Create MetaGraphDef from input graph.
- graph = ops.Graph()
- with graph.as_default():
- importer.import_graph_def(input_graph_def, name="")
- grappler_meta_graph_def = saver.export_meta_graph(
- graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
- if outputs:
- output_collection = meta_graph_pb2.CollectionDef()
- output_list = output_collection.node_list.value
- for i in outputs:
- if isinstance(i, ops.Tensor):
- output_list.append(_to_bytes(i.name))
- else:
- output_list.append(_to_bytes(i))
- # TODO(laigd): use another key as the outputs are really not train_op.
- grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
- output_collection)
-
- # Create TRT-enabled ConfigProto.
- session_config_with_trt = config_pb2.ConfigProto()
- session_config_with_trt.CopyFrom(session_config)
- rewriter_config = None
- if (session_config_with_trt.HasField("graph_options") and
- session_config_with_trt.graph_options.HasField("rewrite_options")):
- rewriter_config = session_config_with_trt.graph_options.rewrite_options
- rewriter_config_with_trt = get_tensorrt_rewriter_config(
- rewriter_config, max_batch_size, max_workspace_size_bytes, precision_mode,
- minimum_segment_size, is_dynamic_op, maximum_cached_engines,
- cached_engine_batches, use_calibration)
- session_config_with_trt.graph_options.rewrite_options.CopyFrom(
- rewriter_config_with_trt)
-
- # Run Grappler.
- transformed_graph_def = tf_optimizer.OptimizeGraph(
- session_config_with_trt, grappler_meta_graph_def, graph_id=b"tf_graph")
-
- # Optionally write the transformed graphdef as SavedModel.
- if output_saved_model_dir is not None:
- saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
- with ops.Graph().as_default():
- importer.import_graph_def(transformed_graph_def, name="")
- # We don't use TRT here.
- with session.Session(config=session_config) as sess:
- saved_model_builder.add_meta_graph_and_variables(
- sess,
- input_saved_model_tags,
- signature_def_map=grappler_meta_graph_def.signature_def)
- # Ignore other meta graphs from the input SavedModel.
- saved_model_builder.save()
-
- return transformed_graph_def
+def create_inference_graph(
+ input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size_bytes=trt_convert.DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+ precision_mode=trt_convert.TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=None,
+ use_calibration=True,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
+ return trt_convert.create_inference_graph(
+ input_graph_def=input_graph_def,
+ outputs=outputs,
+ max_batch_size=max_batch_size,
+ max_workspace_size_bytes=max_workspace_size_bytes,
+ precision_mode=precision_mode,
+ minimum_segment_size=minimum_segment_size,
+ is_dynamic_op=is_dynamic_op,
+ maximum_cached_engines=maximum_cached_engines,
+ cached_engine_batches=cached_engine_batches,
+ use_calibration=use_calibration,
+ input_saved_model_dir=input_saved_model_dir,
+ input_saved_model_tags=input_saved_model_tags,
+ output_saved_model_dir=output_saved_model_dir,
+ session_config=session_config)
def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
- """Convert an existing calibration graph to inference graph.
-
- Args:
- calibration_graph_def: the calibration GraphDef object with calibration data
- is_dynamic_op: whether to create dynamic static engines from calibration
-
- Returns:
- New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
- Raises:
- RuntimeError: if the returned status message is malformed.
- """
-
- is_calib_graph = False
- for n in calibration_graph_def.node:
- if n.op == "TRTEngineOp":
- is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
- if not is_calib_graph:
- tf_logging.error(
- "Not a calib graph. Doesn't seem to contain any calibration nodes.")
- return None
- graph_str = calibration_graph_def.SerializeToString()
- out = calib_convert(graph_str, is_dynamic_op)
- status = _to_string(out[0])
- output_graph_def_string = out[1]
- del graph_str # Save some memory
- if len(status) < 2:
- raise _impl.UnknownError(None, None, status)
- if status[:2] != "OK":
- msg = status.split(";")
- if len(msg) == 1:
- raise RuntimeError("Status message is malformed {}".format(status))
- # pylint: disable=protected-access
- raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
- int(msg[0]))
- # pylint: enable=protected-access
- output_graph_def = graph_pb2.GraphDef()
- output_graph_def.ParseFromString(output_graph_def_string)
- del output_graph_def_string # Save some memory
- return output_graph_def
+ return trt_convert.calib_graph_to_infer_graph(
+ calibration_graph_def=calibration_graph_def, is_dynamic_op=is_dynamic_op)
diff --git a/tensorflow/contrib/tensorrt/test/test_tftrt.py b/tensorflow/contrib/tensorrt/test/test_tftrt.py
deleted file mode 100644
index 090aa8b..0000000
--- a/tensorflow/contrib/tensorrt/test/test_tftrt.py
+++ /dev/null
@@ -1,287 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Script to test TF-TensorRT integration."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import argparse
-import numpy as np
-import six as _six
-
-# normally we should do import tensorflow as tf and then
-# tf.placeholder, tf.constant, tf.nn.conv2d etc but
-# it looks like internal builds don't like it so
-# importing every module individually
-
-from tensorflow.contrib import tensorrt as trt
-from tensorflow.core.protobuf import config_pb2 as cpb2
-from tensorflow.core.protobuf import rewriter_config_pb2 as rwpb2
-from tensorflow.python.client import session as csess
-from tensorflow.python.framework import constant_op as cop
-from tensorflow.python.framework import dtypes as dtypes
-from tensorflow.python.framework import importer as importer
-from tensorflow.python.framework import ops as ops
-from tensorflow.python.ops import array_ops as aops
-from tensorflow.python.ops import math_ops as mops
-from tensorflow.python.ops import nn as nn
-from tensorflow.python.ops import nn_ops as nn_ops
-
-
-def py2bytes(inp):
- return inp
-
-
-def py3bytes(inp):
- return inp.encode("utf-8", errors="surrogateescape")
-
-
-def py2string(inp):
- return inp
-
-
-def py3string(inp):
- return inp.decode("utf-8")
-
-
-if _six.PY2:
- to_bytes = py2bytes
- to_string = py2string
-else:
- to_bytes = py3bytes
- to_string = py3string
-
-
-def get_multi_engine_graph_def(mode="FP32"):
- """Create a simple graph and return its graph_def."""
- dtype = dtypes.float32
- if mode.upper() == "FP16":
- dtype = dtypes.float16
- else:
- pass
-
- g = ops.Graph()
- with g.as_default():
- x = aops.placeholder(shape=[None, 3, 7, 5], name="input", dtype=dtype)
- with g.name_scope("Global_scope"):
- with g.name_scope("first_scope"):
- e = cop.constant(
- np.random.randn(3, 2, 3, 4), name="weights", dtype=dtype)
- conv = nn.conv2d(
- input=x,
- filter=e,
- data_format="NCHW",
- strides=[1, 1, 1, 1],
- padding="VALID",
- name="conv")
- b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias1", dtype=dtype)
- t = conv * b
-
- b = cop.constant(np.random.randn(1, 4, 1, 1), name="bias2", dtype=dtype)
- q = conv / b
- edge = mops.sin(q)
- edge1 = mops.cos(conv)
- with g.name_scope("test_scope"):
- de = edge + edge1
- t -= edge1
- q *= edge
- t += q
- t -= de
- k = aops.squeeze(t, name="output")
- print(k.dtype)
- return g.as_graph_def()
-
-
-def get_simple_graph_def():
- """Create a simple graph and return its graph_def."""
- g = ops.Graph()
- with g.as_default():
- a = aops.placeholder(
- dtype=dtypes.float32, shape=(None, 24, 24, 2), name="input")
- e = cop.constant(
- [[[[1., 0.5, 4., 6., 0.5, 1.], [1., 0.5, 1., 1., 0.5, 1.]]]],
- name="weights",
- dtype=dtypes.float32)
- conv = nn.conv2d(
- input=a, filter=e, strides=[1, 2, 2, 1], padding="SAME", name="conv")
- b = cop.constant(
- [4., 1.5, 2., 3., 5., 7.], name="bias", dtype=dtypes.float32)
- t = nn.bias_add(conv, b, name="biasAdd")
- relu = nn.relu(t, "relu")
- idty = aops.identity(relu, "ID")
- v = nn_ops.max_pool(
- idty, [1, 2, 2, 1], [1, 2, 2, 1], "VALID", name="max_pool")
- aops.squeeze(v, name="output")
- return g.as_graph_def()
-
-
-def execute_graph(gdef, dumm_inp):
- """Run given graphdef once."""
- print("executing")
- gpu_options = None
- if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
- sessconfig = cpb2.ConfigProto(gpu_options=gpu_options)
- ops.reset_default_graph()
- g = ops.Graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=["input", "output"])
- inp = inp.outputs[0]
- out = out.outputs[0]
- with csess.Session(config=sessconfig, graph=g) as sess:
- val = sess.run(out, {inp: dumm_inp})
- return val
-
-
-# Use real data that is representative of the inference dataset
-# for calibration. For this test script it is random data.
-def execute_calibration(gdef, dumm_inp):
- """Run given calibration graph multiple times."""
- gpu_options = None
- if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
- ops.reset_default_graph()
- g = ops.Graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=gdef, return_elements=["input", "output"])
- inp = inp.outputs[0]
- out = out.outputs[0]
- with csess.Session(
- config=cpb2.ConfigProto(gpu_options=gpu_options), graph=g) as sess:
- # run over real calibration data here, we are mimicking a calibration set of
- # 30 different batches. Use as much calibration data as you want
- for _ in range(30):
- val = sess.run(out, {inp: dumm_inp})
- return val
-
-
-def user(multi_engine,
- run_graph=execute_graph,
- run_calibration=execute_calibration):
- """Example function that converts a graph to TFTRT graph."""
- if multi_engine:
- inp_dims = (2, 3, 7, 5)
- orig_graph = get_multi_engine_graph_def()
- else:
- inp_dims = (100, 24, 24, 2)
- orig_graph = get_simple_graph_def() # use a frozen graph for inference
- dummy_input = np.random.random_sample(inp_dims)
- # Get optimized graph
- trt_graph = trt.create_inference_graph(
- input_graph_def=orig_graph,
- outputs=["output"],
- max_batch_size=inp_dims[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode="FP32", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2, # minimum number of nodes in an engine
- is_dynamic_op=False,
- maximum_cached_engines=1,
- cached_engine_batches=[])
- o1 = run_graph(orig_graph, dummy_input)
- o2 = run_graph(trt_graph, dummy_input)
- o3 = run_graph(trt_graph, dummy_input)
- assert np.array_equal(o1, o2)
- assert np.array_equal(o3, o2) # sanity check
- fp16_graph = trt.create_inference_graph(
- input_graph_def=orig_graph,
- outputs=["output"],
- max_batch_size=inp_dims[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode="FP16", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2, # minimum number of nodes in an engine
- is_dynamic_op=False,
- maximum_cached_engines=1,
- cached_engine_batches=[])
- int8_calib_gdef = trt.create_inference_graph(
- input_graph_def=orig_graph,
- outputs=["output"],
- max_batch_size=inp_dims[0],
- max_workspace_size_bytes=1 << 25,
- precision_mode="INT8", # TRT Engine precision "FP32","FP16" or "INT8"
- minimum_segment_size=2, # minimum number of nodes in an engine
- is_dynamic_op=False,
- maximum_cached_engines=1,
- cached_engine_batches=[])
- o4 = run_graph(fp16_graph, dummy_input)
- _ = run_calibration(int8_calib_gdef, dummy_input)
- int8_graph = trt.calib_graph_to_infer_graph(int8_calib_gdef)
- o5 = run_graph(int8_graph, dummy_input)
- print("Is FP32 == FP16? %s (False is possible)" % np.allclose(o1, o4))
- print("Is FP32 == INT8? %s (False is possible)" % np.allclose(o1, o5))
- print("Pass")
-
-
-def auto(multi_engine):
- """Run the conversion as an optimization pass."""
- if multi_engine:
- inp_dims = (2, 3, 7, 5)
- orig_graph = get_multi_engine_graph_def()
- else:
- inp_dims = (100, 24, 24, 2)
- orig_graph = get_simple_graph_def() # use a frozen graph for inference
- dummy_input = np.random.random_sample(inp_dims)
- opt_config = rwpb2.RewriterConfig()
- opt_config.meta_optimizer_iterations = opt_config.ONE
- opt_config.optimizers.extend(["constfold", "layout"])
- custom_op = opt_config.custom_optimizers.add()
- custom_op.name = "TensorRTOptimizer"
- custom_op.parameter_map["minimum_segment_size"].i = 3
- custom_op.parameter_map["precision_mode"].s = to_bytes("FP32")
- custom_op.parameter_map["max_batch_size"].i = inp_dims[0]
- custom_op.parameter_map["max_workspace_size_bytes"].i = 1 << 25
- print(custom_op)
- gpu_options = None
- if trt.trt_convert.get_linked_tensorrt_version()[0] == 3:
- gpu_options = cpb2.GPUOptions(per_process_gpu_memory_fraction=0.50)
- graph_options = cpb2.GraphOptions(rewrite_options=opt_config)
- sessconfig = cpb2.ConfigProto(
- gpu_options=gpu_options, graph_options=graph_options)
- print(sessconfig)
- g = ops.Graph()
- ops.reset_default_graph()
- with g.as_default():
- inp, out = importer.import_graph_def(
- graph_def=orig_graph, return_elements=["input", "output"], name="")
- inp = inp.outputs[0]
- out = out.outputs[0]
- with csess.Session(config=sessconfig, graph=g) as sess:
- val = sess.run(out, {inp: dummy_input})
- print(val.shape)
-
-
-if "__main__" in __name__:
- P = argparse.ArgumentParser(
- prog="tftrt_test",
- description="Example utilization of TensorFlow-TensorRT integration")
- P.add_argument(
- "--automatic",
- "-a",
- action="store_true",
- help="Do TRT conversion automatically",
- default=False)
- P.add_argument(
- "--multi-engine",
- "-m",
- action="store_true",
- help="Use a graph that will result in 2 engines",
- default=False)
- flags, unparsed = P.parse_known_args()
- if flags.automatic:
- auto(flags.multi_engine)
- else:
- user(flags.multi_engine)
diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD
index 2a22295..d1be31d 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/BUILD
+++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD
@@ -155,11 +155,11 @@
py_test(
name = "head_test",
- size = "large",
+ size = "medium",
srcs = [
"head_test.py",
],
- shard_count = 4,
+ shard_count = 10,
srcs_version = "PY2AND3",
tags = ["no_pip_gpu"], # b/63391119
deps = [
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index c1a36fe..7b1a501 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -112,12 +112,12 @@
"functional_ops",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
"//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
"//tensorflow/core:lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
],
)
@@ -134,10 +134,10 @@
"ops/tpu_embedding_ops.cc",
],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
"//tensorflow/contrib/tpu/utils:tpu_embedding_optimization_parameters_utils",
"//tensorflow/contrib/tpu/utils:tpu_embedding_output_layout_utils",
"//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
],
)
@@ -162,14 +162,14 @@
)
tf_custom_op_library(
- name = "python/ops/_tpu_ordinal_selector.so",
+ name = "python/ops/_tpu_ordinal_selector_op.so",
srcs = ["ops/tpu_ordinal_selector_op.cc"],
)
tf_custom_op_py_library(
name = "tpu_ordinal_selector_py",
- srcs = ["ops/gen_tpu_ordinal_selector_op.py"],
- dso = [":python/ops/_tpu_ordinal_selector.so"],
+ srcs = ["python/ops/tpu_ordinal_selector_op.py"],
+ dso = [":python/ops/_tpu_ordinal_selector_op.so"],
kernels = [
":tpu_ordinal_selector_op_op_lib",
],
@@ -187,6 +187,11 @@
],
)
+tf_custom_op_library(
+ name = "python/ops/_functional_ops.so",
+ srcs = ["ops/functional_ops.cc"],
+)
+
tf_gen_op_wrapper_py(
name = "gen_functional_ops",
out = "python/tpu/gen_functional_ops.py",
@@ -196,9 +201,14 @@
deps = [":functional_ops_op_lib"],
)
-py_library(
+tf_custom_op_py_library(
name = "functional",
srcs = ["python/tpu/functional.py"],
+ dso = [":python/ops/_functional_ops.so"],
+ kernels = [
+ ":functional_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
visibility = [
"//visibility:public",
],
@@ -221,7 +231,7 @@
tf_custom_op_py_library(
name = "tpu_py",
- srcs = glob(["python/ops/*.py"]),
+ srcs = ["python/ops/tpu_ops.py"],
dso = [":python/ops/_tpu_ops.so"],
kernels = [
":all_ops",
@@ -274,8 +284,8 @@
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/distribute",
"//tensorflow/contrib/framework:framework_py",
- "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
@@ -321,13 +331,13 @@
"//tensorflow/compiler/xla/python_api:xla_shape",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
"//tensorflow/contrib/compiler:xla",
- "//tensorflow/contrib/tpu/proto:compilation_result_proto_py",
- "//tensorflow/contrib/tpu/proto:dynamic_padding_proto_py",
- "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_py",
- "//tensorflow/contrib/tpu/proto:topology_proto_py",
- "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py",
- "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_py",
"//tensorflow/core:protos_all_py",
+ "//tensorflow/core/protobuf/tpu:compilation_result_proto_py",
+ "//tensorflow/core/protobuf/tpu:dynamic_padding_proto_py",
+ "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_py",
+ "//tensorflow/core/protobuf/tpu:topology_proto_py",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:control_flow_util",
@@ -459,7 +469,7 @@
deps = [
":tpu_lib",
":tpu_ops",
- "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_py",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 676aed0..b991698 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -13,7 +13,6 @@
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
#include "tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h"
#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -23,6 +22,7 @@
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
namespace tensorflow {
@@ -466,7 +466,7 @@
configuration given to tpu.initialize_system.
learning_rates: A TensorList of float32 scalars, one for each dynamic learning
rate tag: see the comments in
- //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto.
+ //third_party/tensorflow/core/protobuf/tpu/optimization_parameters.proto.
Multiple tables can share the same dynamic learning rate tag as specified
in the configuration. If the learning rates for all tables are constant,
this list should be empty.
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index 500dd2c..55f7c6b 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -225,7 +225,7 @@
config: Serialized TPUEmbeddingConfiguration proto.
learning_rates: A TensorList of float32 scalars, one for each dynamic
learning rate tag: see the comments in
- //third_party/tensorflow/contrib/tpu/proto/
+ //third_party/tensorflow/core/protobuf/tpu/
optimization_parameters.proto.
Multiple tables can share the same dynamic learning rate tag as
specified in the configuration. If the learning rates for all tables
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py
new file mode 100644
index 0000000..5ca38cd
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ordinal_selector_op.py
@@ -0,0 +1,38 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+
+"""Operations to select TPU core to run."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import platform
+
+if platform.system() != "Windows":
+ # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+ from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import *
+
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+ _tpu_ordinal_selector_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_tpu_ordinal_selector_op.so"))
+
+else:
+ # We have already built the appropriate libraries into the binary via CMake
+ # if we have built contrib, so we don't need this
+ pass
diff --git a/tensorflow/contrib/tpu/python/tpu/feature_column.py b/tensorflow/contrib/tpu/python/tpu/feature_column.py
index 8edf131..68bcdb5 100644
--- a/tensorflow/contrib/tpu/python/tpu/feature_column.py
+++ b/tensorflow/contrib/tpu/python/tpu/feature_column.py
@@ -420,7 +420,7 @@
else:
# scope contains var_scope_name.
captured_scope = variable_scope.get_variable_scope()
- var_def_dict[embedding_var_name] = (captured_scope,
+ var_def_dict[embedding_var_name] = (captured_scope.name,
embedding_var_name_in_fc)
diff --git a/tensorflow/contrib/tpu/python/tpu/functional.py b/tensorflow/contrib/tpu/python/tpu/functional.py
index 1ec9b5b..24c8515 100644
--- a/tensorflow/contrib/tpu/python/tpu/functional.py
+++ b/tensorflow/contrib/tpu/python/tpu/functional.py
@@ -18,8 +18,22 @@
from __future__ import division
from __future__ import print_function
+import platform
+
from tensorflow.contrib.tpu.python.tpu import gen_functional_ops
TPUPartitionedCall = gen_functional_ops._tpu_partitioned_call # pylint: disable=invalid-name,protected-access
+
+if platform.system() != "Windows":
+ # pylint: disable=wildcard-import,unused-import,g-import-not-at-top
+ from tensorflow.contrib.tpu.ops.gen_tpu_ordinal_selector_op import *
+
+ from tensorflow.contrib.util import loader
+ from tensorflow.python.platform import resource_loader
+ # pylint: enable=wildcard-import,unused-import,g-import-not-at-top
+
+ _tpu_partitioned_call_op = loader.load_op_library(
+ resource_loader.get_path_to_datafile("../ops/_functional_ops.so")
+ )
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 37fe9af..4322d17 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -56,7 +56,6 @@
from tensorflow.contrib.cluster_resolver.python.training import tpu_cluster_resolver as tpu_cluster_resolver_lib
from tensorflow.contrib.framework.python.framework import experimental
-from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import keras_tpu_variables
from tensorflow.contrib.tpu.python.tpu import tpu
@@ -64,6 +63,7 @@
from tensorflow.contrib.tpu.python.tpu import tpu_optimizer
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
diff --git a/tensorflow/contrib/tpu/python/tpu/session_support.py b/tensorflow/contrib/tpu/python/tpu/session_support.py
index f5735ce..5cb2ca6 100644
--- a/tensorflow/contrib/tpu/python/tpu/session_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/session_support.py
@@ -172,7 +172,8 @@
"""Shutdown all workers after `shutdown_timeout_secs`."""
logging.info('Shutting down %s.', self)
req = event_pb2.WorkerHeartbeatRequest(
- watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms))
+ watchdog_config=event_pb2.WatchdogConfig(timeout_ms=timeout_ms),
+ shutdown_mode=event_pb2.WAIT_FOR_COORDINATOR)
self.configure(req)
# Wait for workers to shutdown. This isn't strictly required
diff --git a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
index bf492e7..2c5ea65 100644
--- a/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
+++ b/tensorflow/contrib/tpu/python/tpu/tensor_tracer.py
@@ -30,11 +30,15 @@
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_math_ops
+from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import logging_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
@@ -67,14 +71,18 @@
_SECTION_NAME_REASON = 'reason'
_SECTION_NAME_OP_LIST = 'op-list'
_SECTION_NAME_TENSOR_LIST = 'tensor-list'
+_SECTION_NAME_CACHE_INDEX_MAP = 'cache-index-map'
_SECTION_NAME_GRAPH = 'graph'
_FIELD_NAME_VERSION = 'version:'
_FIELD_NAME_DEVICE = 'device:'
_FIELD_NAME_TRACE_MODE = 'trace-mode:'
_FIELD_NAME_SUBMODE = 'submode:'
_FIELD_NAME_NUM_REPLICAS = 'num-replicas:'
+_FIELD_NAME_NUM_REPLICAS_PER_HOST = 'num-replicas-per-host:'
+_FIELD_NAME_NUM_HOSTS = 'num-hosts:'
_FIELD_NAME_NUM_OPS = 'number-of-ops:'
_FIELD_NAME_NUM_TENSORS = 'number-of-tensors:'
+_FIELD_NAME_NUM_CACHE_INDICES = 'number-of-indices:'
_FIELD_NAME_TOPOLOGICAL_SORT_SUCCEED = 'topological-sort-succeed:'
_FLAGS_ENV_VAR = 'TENSOR_TRACER_FLAGS'
_FLAG_SINGLE_QUOTE_PAT = re.compile(r"\s*--([^=]+)='([^']*)'")
@@ -83,14 +91,15 @@
_FLAG_NO_EQUAL_PAT = re.compile(r'\s*--([^=]+)\s*')
_FLAG_NAME_ENABLE = 'enable'
_FLAG_NAME_TRACE_MODE = 'trace_mode'
+_FLAG_NAME_USE_COMPACT_TRACE = 'compact_trace'
_FLAG_NAME_SUBMODE = 'submode'
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS = 'include_less_interesting_ops'
_FLAG_NAME_EXCLUDED_OPNAMES = 'excluded_opnames'
_FLAG_NAME_EXCLUDED_OPTYPES = 'excluded_optypes'
_FLAG_NAME_INCLUDED_OPNAMES = 'included_opnames'
_FLAG_NAME_INCLUDED_OPTYPES = 'included_optypes'
-_FLAG_NAME_TRACE_FILE = 'trace_file_path'
-_FLAG_NAME_REPORT_FILE = 'report_file_path'
+_FLAG_NAME_TRACE_DIR = 'trace_dir'
+_FLAG_NAME_REPORT_FILE = 'report_file'
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR = 'use_test_undeclared_outputs_dir'
_FLAG_NAME_OP_RANGE = 'op_range'
_OP_RANGE_PAT = re.compile(r'(\d+):(\d+)')
@@ -98,7 +107,12 @@
_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR = 'TEST_UNDECLARED_OUTPUTS_DIR'
_TENSOR_TRACER_COLLECTION = 'tensor_tracer_variables'
_TENSOR_TRACER_CHECKPOINT = 'tensor_tracer_checkpoint'
-
+_TRACE_FILE_NAME = 'trace.all'
+_COMPACT_TRACE_FILE_PREFIX = 'compact_trace.'
+_COMPACT_TRACE_ENTRY_INIT_VALUE = -1.0
+_TENSOR_TRACER_STORAGE = 'tensor_tracer_storage'
+_TENSOR_VALUES_CACHE = 'tensor_values_cache'
+_REPLICA_ID_TAG = '#replica-id: '
def tensor_tracepoint(tensor, checkpoint_name):
"""Adds a checkpoint with the given checkpoint name for the given tensor.
@@ -152,6 +166,68 @@
return layer
+def _trace_files_need_precreated(output_dir):
+ """Return True if trace files must be pre-created by users."""
+
+ if not output_dir.startswith('/'):
+ return False
+ if len(output_dir) < 5:
+ return False
+ if output_dir[2] != 'n':
+ return False
+ if output_dir[3] != 's':
+ return False
+ if output_dir[1] != 'c':
+ return False
+ if output_dir[4] != '/':
+ return False
+ return True
+
+
+def _get_tensor_values_cache(graph=None):
+ """Returns the variable that implements tensor-value caching."""
+
+ graph = graph or ops.get_default_graph()
+ collection = graph.get_collection(_TENSOR_TRACER_STORAGE)
+ if len(collection) == 1:
+ return collection[0]
+ elif not collection:
+ raise RuntimeError('%s has not been created'%_TENSOR_VALUES_CACHE)
+ else:
+ raise RuntimeError('Multiple %s created'%_TENSOR_VALUES_CACHE)
+ return None
+
+
+def _create_tensor_values_cache(graph, num_tensors):
+ """Creates a variable as the cache to store intermediate tensor values."""
+
+ graph = graph or ops.get_default_graph()
+ # Create in proper graph and base name_scope.
+ with graph.as_default() as g, g.name_scope(None):
+ return variable_scope.get_variable(
+ _TENSOR_VALUES_CACHE,
+ shape=[num_tensors],
+ dtype=dtypes.float32,
+ initializer=init_ops.constant_initializer(
+ _COMPACT_TRACE_ENTRY_INIT_VALUE),
+ trainable=False,
+ use_resource=True,
+ collections=[_TENSOR_TRACER_STORAGE, ops.GraphKeys.GLOBAL_VARIABLES])
+
+
+def _set_fetches(result_tensor, train_op):
+ """Sets the fetches from the result tensor and training op."""
+
+ fetches = []
+ if result_tensor is not None:
+ fetches.append(result_tensor)
+ if train_op is not None:
+ fetches.append(train_op)
+ if not fetches:
+ return None
+ return fetches
+
+
class TensorTracer(object):
"""A software construct for tracing tensor values in a TF graph on TPU.
@@ -203,12 +279,14 @@
def validate_flag_names():
"""Validates if the TensorTrace flags passed are valid."""
valid_flag_names = [_FLAG_NAME_ENABLE, _FLAG_NAME_TRACE_MODE,
+ _FLAG_NAME_USE_COMPACT_TRACE,
_FLAG_NAME_SUBMODE,
_FLAG_NAME_EXCLUDED_OPNAMES,
_FLAG_NAME_EXCLUDED_OPTYPES,
_FLAG_NAME_INCLUDED_OPNAMES,
_FLAG_NAME_INCLUDED_OPTYPES,
- _FLAG_NAME_TRACE_FILE, _FLAG_NAME_REPORT_FILE,
+ _FLAG_NAME_TRACE_DIR,
+ _FLAG_NAME_REPORT_FILE,
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR,
_FLAG_NAME_INCLUDE_LESS_INTERESTING_OPS,
_FLAG_NAME_OP_RANGE]
@@ -338,6 +416,10 @@
return TensorTracer._is_flag_on(
_FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR)
+ @staticmethod
+ def use_compact_trace():
+ return TensorTracer._is_flag_on(
+ _FLAG_NAME_USE_COMPACT_TRACE)
@staticmethod
def check_device_type(device_type):
@@ -535,7 +617,7 @@
TensorTracer.check_submode(self._submode)
self._part_tensor_size = _TRACE_MODE_PART_TENSOR_SIZE
self._instrument_records = {}
- self._set_trace_file_path()
+ self._set_trace_dir()
self._set_report_file()
self._set_op_range()
self._set_excluded_opnames()
@@ -543,17 +625,17 @@
self._set_included_opnames()
self._set_included_optypes()
self._num_replicas = None
+ self._num_replicas_per_host = None
+ self._num_hosts = None
self._replica_id = None
- def _add_replica_id_to_graph(self, num_replicas, result_tensor):
+ def _add_replica_id_to_graph(self, result_tensor):
"""Adds nodes for computing the replica ID to the graph."""
- if not num_replicas:
+ if not self._num_replicas:
self._replica_id = 'unknown'
return result_tensor
- self._num_replicas = num_replicas
-
with ops.control_dependencies(None):
# Uses None as dependency to run outside of TPU graph rewrites.
self._replica_id = tpu_ops.tpu_replicated_input(
@@ -565,20 +647,15 @@
# the replica_id to ensure that replica_id will be added to the graph.
return array_ops.identity(result_tensor)
- def _set_trace_file_path(self):
- """Sets the path of the output trace file."""
-
- found, self._trace_file_path = TensorTracer.get_flag_value(
- _FLAG_NAME_TRACE_FILE)
- if found and self._trace_file_path \
+ def _set_trace_dir(self):
+ found, self._trace_dir = TensorTracer.get_flag_value(_FLAG_NAME_TRACE_DIR)
+ if found and self._trace_dir \
and TensorTracer.use_test_undeclared_outputs_dir():
- if os.path.isabs(self._trace_file_path):
- raise ValueError('If use_test_undeclared_outputs_dir is set,'
- 'trace_file_path cannot be an absolute path (%s)'
- %self._trace_file_path)
- outputs_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
- self._trace_file_path = os.path.join(outputs_dir,
- self._trace_file_path)
+ raise ValueError('Cannot not use --%s and --%s at the same time'
+ %(_FLAG_NAME_TRACE_DIR,
+ _FLAG_NAME_USE_TEST_UNDECLARED_OUTPUTS_DIR))
+ if TensorTracer.use_test_undeclared_outputs_dir():
+ self._trace_dir = os.environ.get(_TEST_UNDECLARED_OUTPUTS_DIR_ENV_VAR)
def _set_report_file(self):
"""Sets the path of the output report file."""
@@ -660,6 +737,25 @@
return True
return False
+ def _use_tensor_values_cache(self):
+ """Returns True if immediate tensors should be first saved to a cache."""
+
+ if self._trace_mode not in set([_TRACE_MODE_NAN_INF,
+ _TRACE_MODE_NORM, _TRACE_MODE_MAX_ABS]):
+ return False
+ if self._trace_dir and _trace_files_need_precreated(self._trace_dir):
+ return True
+ if TensorTracer.use_compact_trace():
+ return True
+ return False
+
+ def _save_tensor_value_to_cache_op(self, graph, cache_idx, updates):
+ """Returns an Op that will save the given updates to an entry in the cache."""
+
+ cache = _get_tensor_values_cache(graph)
+ indices = constant_op.constant([cache_idx])
+ return state_ops.scatter_update(cache, indices, updates).op
+
def _write_report(self, content):
"""Writes the given content to the report."""
@@ -678,6 +774,9 @@
self._write_report('%s %s\n'%(_FIELD_NAME_TRACE_MODE, self._trace_mode))
self._write_report('%s %s\n'%(_FIELD_NAME_SUBMODE, self._submode))
self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS, self._num_replicas))
+ self._write_report('%s %s\n'%(_FIELD_NAME_NUM_REPLICAS_PER_HOST,
+ self._num_replicas_per_host))
+ self._write_report('%s %s\n'%(_FIELD_NAME_NUM_HOSTS, self._num_hosts))
self._write_report('%s %s\n'%(_MARKER_SECTION_END, _SECTION_NAME_CONFIG))
def _write_reason_section(self):
@@ -724,6 +823,20 @@
self._write_report('%s %s\n'%(_MARKER_SECTION_END,
_SECTION_NAME_TENSOR_LIST))
+ def _write_cache_index_map_section(self):
+ """Writes the mapping from cache index to tensor index to the report."""
+
+ self._write_report('%s %s\n'%(_MARKER_SECTION_BEGIN,
+ _SECTION_NAME_CACHE_INDEX_MAP))
+ self._write_report('%s %d\n'%(_FIELD_NAME_NUM_CACHE_INDICES,
+ len(self._cache_idx_to_tensor_idx)))
+ for cache_idx in range(0, len(self._cache_idx_to_tensor_idx)):
+ tensor_idx = self._cache_idx_to_tensor_idx[cache_idx]
+ line = '%d %d\n'%(cache_idx, tensor_idx)
+ self._write_report(line)
+ self._write_report('%s %s\n'%(_MARKER_SECTION_END,
+ _SECTION_NAME_CACHE_INDEX_MAP))
+
def _write_graph_section(self, succeed, sorted_or_cycle):
"""Writes the graph section of the report."""
@@ -750,11 +863,14 @@
"""Trace function for detecting any NaN/Inf in the tensor."""
if tensor.dtype.is_floating:
- output_tensor = math_ops.reduce_any(
+ mask = math_ops.reduce_any(
gen_math_ops.logical_or(
gen_math_ops.is_nan(tensor), gen_math_ops.is_inf(tensor)))
+ output_tensor = control_flow_ops.cond(mask,
+ lambda: constant_op.constant(1.0),
+ lambda: constant_op.constant(0.0))
else:
- output_tensor = constant_op.constant(False)
+ output_tensor = constant_op.constant(0.0)
# The shape has to be 1. Set it if it does not have the information.
output_tensor = array_ops.reshape(output_tensor, [1])
return output_tensor
@@ -826,8 +942,9 @@
else:
msg = '"%s"'%tensor_name
- if self._trace_file_path:
- output_stream = _OUTPUT_STREAM_ESCAPE + self._trace_file_path
+ if self._trace_dir:
+ output_path = os.path.join(self._trace_dir, _TRACE_FILE_NAME)
+ output_stream = _OUTPUT_STREAM_ESCAPE + output_path
else:
output_stream = sys.stderr
print_op = logging_ops.print_v2(msg, array_ops.shape(output_tensor),
@@ -954,6 +1071,7 @@
def _filter_execution_path_operations(self, operations, fetches):
"""Returns the set of ops in the execution path to compute given fetches."""
+
# If no fetch provided, then return all operations.
if fetches is None:
return set(operations)
@@ -986,17 +1104,85 @@
traverse_stack.append(input_op)
return execution_path_operations
- def _pre_tracing(self, graph):
+ def _determine_traced_tensors(self, graph, fetches):
+ """Determines the tensors that will be traced."""
+
+ self._traced_tensorname_to_cache_idx_map = {}
+ self._cache_idx_to_tensor_idx = []
+ operations = graph.get_operations()
+ # Filter out the operations that won't be executed.
+ # if fetches=None, then ops_in_exec_path = set(operations)
+ ops_in_exec_path = self._filter_execution_path_operations(operations,
+ fetches)
+ checkpoint_operations = self._get_checkpoints(graph)
+ for op_id, op in enumerate(operations):
+ if checkpoint_operations and op.name not in checkpoint_operations:
+ continue
+ user_included = self._is_user_included_op(op)
+ user_excluded = self._is_user_excluded_op(op)
+ in_exec_path = op in ops_in_exec_path
+ if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
+ continue
+ for i in range(len(op.outputs)):
+ out_tensor = op.outputs[i]
+ if self._skip_tensor(op_id, out_tensor, user_included,
+ user_excluded):
+ continue
+ tensor_name = out_tensor.name
+ if tensor_name in self._traced_tensorname_to_cache_idx_map:
+ raise ValueError(
+ 'Tensor name %s should not be already in '
+ 'traced_tensorname_to_cache_idx_map'%tensor_name)
+ if tensor_name not in self._tensorname_idx_map:
+ raise ValueError(
+ 'Tensor name %s is not in the tensorname_idx_map'%tensor_name)
+ tensor_idx = self._tensorname_idx_map[tensor_name]
+ cache_idx = len(self._traced_tensorname_to_cache_idx_map)
+ self._traced_tensorname_to_cache_idx_map[tensor_name] = cache_idx
+ self._cache_idx_to_tensor_idx.append(tensor_idx)
+ if len(self._traced_tensorname_to_cache_idx_map) != len(
+ self._cache_idx_to_tensor_idx):
+ raise RuntimeError('len(self._traced_tensorname_to_cache_idx_map) != '
+ 'len(self._cache_idx_to_tensor_idx')
+
+ def _check_trace_files(self):
+ """Checks if any requirements for trace files are satisfied."""
+
+ if not self._trace_dir:
+ # traces will be written to stderr. No need to check trace files.
+ return
+ if _trace_files_need_precreated(self._trace_dir):
+ for replica_id in range(0, self._num_replicas):
+ trace_file_path = os.path.join(
+ self._trace_dir,
+ _COMPACT_TRACE_FILE_PREFIX) + '%d'%replica_id
+ if not gfile.Exists(trace_file_path):
+ raise RuntimeError(
+ '%s must be pre-created with the '
+ 'appropriate properties.'%trace_file_path)
+ else:
+ if not gfile.Exists(self._trace_dir):
+ gfile.MkDir(self._trace_dir)
+ if not gfile.Exists(self._trace_dir):
+ raise RuntimeError('Failed to create %s'%self._trace_dir)
+
+ def _pre_tracing(self, graph, fetches):
"""Work needs to be done prior to TPU or CPU tracing."""
+ self._check_trace_files()
operations = graph.get_operations()
(opname_idx_map, tensor_list, self._tensorname_idx_map) = (
TensorTracer._make_op_and_tensor_maps(operations))
self._write_config_section()
self._write_op_list_section(operations)
self._write_tensor_list_section(tensor_list, opname_idx_map)
+ self._determine_traced_tensors(graph, fetches)
+ self._write_cache_index_map_section()
# Does the topological sort before adding any nodes to the graph.
(succeed, sorted_or_cycle) = TensorTracer.topological_sort(graph)
+ if self._use_tensor_values_cache():
+ _create_tensor_values_cache(graph,
+ len(self._cache_idx_to_tensor_idx))
return (operations, succeed, sorted_or_cycle)
def _post_tracing(self, succeed, sorted_or_cycle):
@@ -1027,15 +1213,118 @@
_TENSOR_TRACER_CHECKPOINT))
return checkpoint_operations
- def trace_tpu(self, graph, result_tensor, num_replicas=None, fetches=None):
+ def _generate_flush_cache_op(self, graph, start_replica, on_tpu):
+ """Generates an Op that will flush the cache to file.
+
+ Args:
+ graph: the graph of Ops
+ start_replica: the ID of the first replica being flushed by this Op.
+ on_tpu: if the graph is executed on TPU.
+
+ Returns:
+ The Op to flush the cache to file.
+ """
+ def _make_flush_fun(replica_id):
+ """Makes a function for flushing the cache for the given replica."""
+
+ def _fun():
+ """A function that flushes the cache to a file."""
+
+ def _flush_fun(cache):
+ """Flushes the cache to a file."""
+
+ if isinstance(replica_id, str):
+ replica_id_str = replica_id
+ else:
+ replica_id_str = '%d'%replica_id
+ output_path = os.path.join(self._trace_dir,
+ _COMPACT_TRACE_FILE_PREFIX) \
+ + replica_id_str
+ output_stream = _OUTPUT_STREAM_ESCAPE + output_path
+ new_step_line = _REPLICA_ID_TAG + replica_id_str
+ print_op = logging_ops.print_v2(
+ new_step_line, '\n',
+ cache, '\n',
+ summarize=-1,
+ output_stream=output_stream)
+ with ops.control_dependencies([print_op]):
+ return constant_op.constant(0).op
+
+ cache = _get_tensor_values_cache(graph)
+ if on_tpu:
+ flush_op = tpu.outside_compilation(_flush_fun, cache.value())
+ else:
+ flush_op = _flush_fun(cache.value())
+ with ops.control_dependencies([flush_op]):
+ reset_value = constant_op.constant(_COMPACT_TRACE_ENTRY_INIT_VALUE,
+ dtype=cache.dtype,
+ shape=cache.shape)
+ assign_op = state_ops.assign(cache, reset_value).op
+ with ops.control_dependencies([assign_op]):
+ return flush_op.outputs[0]
+
+ return _fun
+
+ def _f(replica_id):
+ return _make_flush_fun(replica_id)
+ def _eq(x):
+ return math_ops.equal(x, self._replica_id)
+ def _do_nothing():
+ return constant_op.constant(0)
+
+ return control_flow_ops.case({\
+ _eq(start_replica): _f(start_replica), \
+ _eq(start_replica+1): _f(start_replica+1), \
+ _eq(start_replica+2): _f(start_replica+2), \
+ _eq(start_replica+3): _f(start_replica+3), \
+ _eq(start_replica+4): _f(start_replica+4), \
+ _eq(start_replica+5): _f(start_replica+5), \
+ _eq(start_replica+6): _f(start_replica+6), \
+ _eq(start_replica+7): _f(start_replica+7), \
+ },
+ default=_do_nothing,
+ exclusive=True).op
+
+ def _flush_tensor_values_cache(self, graph, result_tensor, train_op, on_tpu):
+ """Flushes the intermediate tensor values in the graph to the cache.
+
+ Args:
+ graph: the graph of Ops
+ result_tensor: a result tensor of evaluating the graph.
+ train_op: the training op.
+ on_tpu: if the graph is executed on TPU.
+
+ Returns:
+ An identical copy of result tensor.
+ """
+
+ train_op_list = []
+ if train_op is not None:
+ train_op_list.append(train_op)
+ with ops.control_dependencies(train_op_list):
+ flush_cache_op_list = []
+ for host in range(self._num_hosts):
+ start_replica = host * 8
+ flush_op = self._generate_flush_cache_op(graph, start_replica, on_tpu)
+ flush_cache_op_list.append(flush_op)
+ with ops.control_dependencies(flush_cache_op_list):
+ return array_ops.identity(result_tensor)
+
+ def trace_tpu(self, graph,
+ result_tensor,
+ train_op,
+ num_replicas=None,
+ num_replicas_per_host=None,
+ num_hosts=None):
"""Traces the tensors generated by TPU Ops in a TF graph.
Args:
graph: the graph of Ops executed on the TPU.
result_tensor: a result tensor of evaluating the graph.
+ train_op: the training op.
num_replicas: number of replicas used on the TPU.
- fetches: the list of fetches given to session.run, used to determine the
- ops in execution path. If None, the whole graph will be traced.
+ num_replicas_per_host: number of replicas per TPU host.
+ num_hosts: total number of TPU hosts.
Returns:
A tuple (result_tensor_copy, tracing_ops), where:
@@ -1045,6 +1334,9 @@
should pose control dependencies upon these
Ops so that they will be executed when the
graph is evaluated.
+
+ Raises:
+ RuntimeError: If num_replicas_per_host > 8.
"""
def _cast_unsupported_dtypes(tensor):
@@ -1060,88 +1352,130 @@
return tensor
self._device_type = _DEVICE_TYPE_TPU
- TensorTracer.check_device_type(self._device_type)
- result_tensor_copy = self._add_replica_id_to_graph(num_replicas,
- result_tensor)
- (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
- # Filter out the operations that won't be executed.
- # if fetches=None, then ops_in_exec_path = set(operations)
- ops_in_exec_path = self._filter_execution_path_operations(operations,
- fetches)
- tracing_ops = []
- checkpoint_operations = self._get_checkpoints(graph)
+ self._num_replicas = num_replicas
+ self._num_replicas_per_host = num_replicas_per_host
+ self._num_hosts = num_hosts
+ if self._num_replicas_per_host > 8:
+ # Checks for the assumption in _generate_flush_cache_op().
+ raise RuntimeError(
+ 'num_replicas_per_host (%d) is '
+ 'greater than 8'%self._num_replicas_per_host)
- for op_id, op in enumerate(operations):
- if checkpoint_operations and op.name not in checkpoint_operations:
- continue
- user_included = self._is_user_included_op(op)
- user_excluded = self._is_user_excluded_op(op)
- in_exec_path = op in ops_in_exec_path
- if self._skip_op(op_id, op, user_included, user_excluded, in_exec_path):
- continue
+ TensorTracer.check_device_type(self._device_type)
+ result_tensor_copy = self._add_replica_id_to_graph(result_tensor)
+ fetches = _set_fetches(result_tensor, train_op)
+ (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches)
+
+ tracing_ops = []
+ for op in operations:
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
- if self._skip_tensor(op_id, out_tensor, user_included,
- user_excluded):
+ tensor_name = out_tensor.name
+ if tensor_name not in self._traced_tensorname_to_cache_idx_map:
continue
# Create the list of consumers before calling _preprocess_traced_tensor.
# Otherwise, adding control input below, will introduce a cycle in the
# graph.
consumers = out_tensor.consumers()
- tensor_name = out_tensor.name
+ if not consumers:
+ continue
processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
processed_out_tensor = _cast_unsupported_dtypes(processed_out_tensor)
- trace_op = tpu.outside_compilation(
- self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
- if consumers:
- for consumer_op in consumers:
- # pylint: disable=protected-access
- consumer_op._add_control_input(trace_op)
- # pylint: enable=protected-access
+ if self._use_tensor_values_cache():
+ cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
+ trace_op = self._save_tensor_value_to_cache_op(graph,
+ cache_idx,
+ processed_out_tensor)
else:
- # if there is no consumer, we will add the control dependence later
- # when we add the control dependency to the output operations.
- tracing_ops.append(trace_op)
+ trace_op = tpu.outside_compilation(
+ self._make_tensor_trace_fun(tensor_name), processed_out_tensor)
+ for consumer_op in consumers:
+ # pylint: disable=protected-access
+ consumer_op._add_control_input(trace_op)
+ # pylint: enable=protected-access
+ if self._use_tensor_values_cache():
+ result_tensor_final = self._flush_tensor_values_cache(graph,
+ result_tensor_copy,
+ train_op,
+ on_tpu=True)
+ else:
+ result_tensor_final = result_tensor_copy
self._post_tracing(succeed, sorted_or_cycle)
- return (result_tensor_copy, tracing_ops)
+ return (result_tensor_final, tracing_ops)
- def trace_cpu(self, graph):
+ def _generate_cpu_result(self, result_tensor, train_op, graph):
+ """Generates the final CPU result."""
+
+ if self._use_tensor_values_cache():
+ result_tensor_final = self._flush_tensor_values_cache(graph,
+ result_tensor,
+ train_op,
+ on_tpu=False)
+ else:
+ result_tensor_final = array_ops.identity(result_tensor)
+ return result_tensor_final
+
+ def trace_cpu(self, graph, result_tensor, train_op):
"""Traces the tensors generated by CPU Ops in a TF graph.
Args:
graph: the graph of Ops executed on the CPU.
+ result_tensor: a result tensor of evaluating the graph.
+ train_op: the training op.
Returns:
- tracing_calls: a map from keys to trace calls.
+ A pair (final_result_tensor, tracing_calls) where:
+ final_result_tensor: an identical copy of result_tensor.
+ tracing_calls: a map from keys to trace calls.
A key is constructed from an Op's name.
A trace call consists of a function and a tensor (
the function will be invoked with the tensor).
"""
+ if result_tensor is None:
+ raise ValueError(
+ 'The result_tensor passed to trace_cpu should not be None')
+
self._device_type = _DEVICE_TYPE_CPU
TensorTracer.check_device_type(self._device_type)
self._num_replicas = 1
+ self._num_replicas_per_host = 1
+ self._num_hosts = 1
self._replica_id = 0
- (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph)
- tracing_calls = {}
- checkpoint_operations = self._get_checkpoints(graph)
+ fetches = _set_fetches(result_tensor, train_op)
+ (operations, succeed, sorted_or_cycle) = self._pre_tracing(graph, fetches)
- for op_id, op in enumerate(operations):
- if checkpoint_operations and op.name not in checkpoint_operations:
- continue
- user_included = self._is_user_included_op(op)
- user_excluded = self._is_user_excluded_op(op)
- if self._skip_op(op_id, op, user_included, user_excluded):
- continue
+ tracing_calls = {}
+ for op in operations:
for i in range(len(op.outputs)):
out_tensor = op.outputs[i]
- if self._skip_tensor(op_id, out_tensor, user_included,
- user_excluded):
+ tensor_name = out_tensor.name
+ if tensor_name not in self._traced_tensorname_to_cache_idx_map:
+ continue
+ # Create the list of consumers before calling _preprocess_traced_tensor.
+ # Otherwise, adding control input below, will introduce a cycle in the
+ # graph.
+ consumers = out_tensor.consumers()
+ if not consumers:
continue
processed_out_tensor = self._preprocess_traced_tensor(out_tensor)
- trace_fun = self._make_tensor_trace_fun(out_tensor.name)
- trace_call = (trace_fun, [processed_out_tensor])
- trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i)
- tracing_calls[trace_call_key] = trace_call
+ if self._use_tensor_values_cache():
+ cache_idx = self._traced_tensorname_to_cache_idx_map[tensor_name]
+ trace_op = self._save_tensor_value_to_cache_op(graph,
+ cache_idx,
+ processed_out_tensor)
+ for consumer_op in consumers:
+ # pylint: disable=protected-access
+ consumer_op._add_control_input(trace_op)
+ # pylint: enable=protected-access
+ else:
+ trace_fun = self._make_tensor_trace_fun(tensor_name)
+ trace_call = (trace_fun, [processed_out_tensor])
+ trace_call_key = 'tensor_tracing_cpu-%s:%d'%(op.name, i)
+ tracing_calls[trace_call_key] = trace_call
+
self._post_tracing(succeed, sorted_or_cycle)
- return tracing_calls
+ final_result_tensor = self._generate_cpu_result(result_tensor,
+ train_op,
+ graph)
+ return (final_result_tensor, tracing_calls)
diff --git a/tensorflow/contrib/tpu/python/tpu/topology.py b/tensorflow/contrib/tpu/python/tpu/topology.py
index 6ae718c..00ee21e 100644
--- a/tensorflow/contrib/tpu/python/tpu/topology.py
+++ b/tensorflow/contrib/tpu/python/tpu/topology.py
@@ -21,7 +21,7 @@
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.contrib.tpu.proto import topology_pb2
+from tensorflow.core.protobuf.tpu import topology_pb2
def _tpu_device_name(job, task, device):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index de2bfd4..673129b 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -24,11 +24,10 @@
from tensorflow.contrib.compiler import xla
from tensorflow.contrib.framework.python.framework import experimental
-from tensorflow.contrib.tpu.proto import dynamic_padding_pb2 as dynamic_padding
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_function
-
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.protobuf.tpu import dynamic_padding_pb2 as dynamic_padding
from tensorflow.python.compat import compat as api_compat
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
index 0e4597b..8731b93 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_embedding.py
@@ -26,9 +26,10 @@
from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.tpu.ops import gen_tpu_ops
-from tensorflow.contrib.tpu.proto import tpu_embedding_configuration_pb2 as elc
from tensorflow.contrib.tpu.python.ops import tpu_ops
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
+from tensorflow.core.protobuf.tpu import optimization_parameters_pb2
+from tensorflow.core.protobuf.tpu import tpu_embedding_configuration_pb2 as elc
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -473,6 +474,11 @@
table_descriptor.optimization_parameters.learning_rate.constant = (
self._optimization_parameters.learning_rate)
+ table_descriptor.optimization_parameters.gradient_accumulation_status = (
+ optimization_parameters_pb2.GradientAccumulationStatus.ENABLED
+ if self._optimization_parameters.use_gradient_accumulation else
+ optimization_parameters_pb2.GradientAccumulationStatus.DISABLED)
+ # For compatibility with old TPU workers.
table_descriptor.optimization_parameters.use_gradient_accumulation = (
self._optimization_parameters.use_gradient_accumulation)
self._optimizer_handler.set_optimization_parameters(table_descriptor)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 10f3451..a372d44 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -31,9 +31,8 @@
from six.moves import queue as Queue # pylint: disable=redefined-builtin
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.contrib.tpu.ops import gen_tpu_ordinal_selector_op
-from tensorflow.contrib.tpu.proto import compilation_result_pb2 as tpu_compilation_result
from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.ops import tpu_ordinal_selector_op
from tensorflow.contrib.tpu.python.tpu import _tpu_estimator_embedding
from tensorflow.contrib.tpu.python.tpu import error_handling
from tensorflow.contrib.tpu.python.tpu import functional as tpu_functional
@@ -51,6 +50,7 @@
from tensorflow.core.framework import variable_pb2
from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf.tpu import compilation_result_pb2 as tpu_compilation_result
from tensorflow.python.client import session as tf_session
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest as data_nest
@@ -354,9 +354,12 @@
hooks = None
if self.host_call is not None:
hooks = [_OutfeedHostCallHook(host_call_ret['host_call'])]
- if tensor_tracer.TensorTracer.is_enabled():
+ loss = self.loss
+ if tensor_tracer.TensorTracer.is_enabled() \
+ and self.train_op is not None:
tt = tensor_tracer.TensorTracer()
- tracing_calls = tt.trace_cpu(ops.get_default_graph())
+ (loss, tracing_calls) = tt.trace_cpu(ops.get_default_graph(),
+ loss, self.train_op)
tracing_call_ret = _OutfeedHostCall.create_cpu_hostcall(tracing_calls)
tracing_functions = tracing_call_ret.values()
if tracing_functions:
@@ -369,7 +372,7 @@
return model_fn_lib.EstimatorSpec(
mode=self.mode,
predictions=self.predictions,
- loss=self.loss,
+ loss=loss,
train_op=self.train_op,
eval_metric_ops=eval_metric_ops,
export_outputs=self.export_outputs,
@@ -464,7 +467,11 @@
self._feed_error = None
self._finished = False
- self._should_initialize_tpu = True
+ # When using model parallelism, the TPU is pre-initialized at startup to
+ # fetch mesh information. We skip re-initializing it here to avoid
+ # suspected issues due to the mesh layout changing on the second
+ # initialization.
+ self._should_initialize_tpu = not ctx.model_parallelism_enabled
self._tpu_compile_op = tpu_compile_op
def begin(self):
@@ -1370,7 +1377,7 @@
return tpu_functional.TPUPartitionedCall(
args=tpu_subgraph.captured_inputs,
- device_ordinal=gen_tpu_ordinal_selector_op.tpu_ordinal_selector(),
+ device_ordinal=tpu_ordinal_selector_op.tpu_ordinal_selector(),
Tout=[o.type for o in tpu_subgraph.definition.signature.output_arg],
f=tpu_subgraph)
else:
@@ -1450,9 +1457,11 @@
tracing_ops = []
if tensor_tracer.TensorTracer.is_enabled():
tt = tensor_tracer.TensorTracer()
- loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(), loss,
+ loss, tracing_ops = tt.trace_tpu(ops.get_default_graph(),
+ loss, train_op,
self._ctx.num_replicas,
- fetches=[loss, train_op])
+ self._ctx.num_of_replicas_per_host,
+ self._ctx.num_hosts)
if self._ctx.embedding_config is None:
apply_sparse_grads = []
diff --git a/tensorflow/contrib/tpu/utils/BUILD b/tensorflow/contrib/tpu/utils/BUILD
index c27b737..5cbed40 100644
--- a/tensorflow/contrib/tpu/utils/BUILD
+++ b/tensorflow/contrib/tpu/utils/BUILD
@@ -8,9 +8,9 @@
hdrs = ["tpu_embedding_optimization_parameters_utils.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/tpu/proto:optimization_parameters_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_proto_parsing",
+ "//tensorflow/core/protobuf/tpu:optimization_parameters_proto_cc",
"@com_google_absl//absl/base",
],
)
@@ -21,10 +21,10 @@
hdrs = ["tpu_embedding_output_layout_utils.h"],
visibility = ["//visibility:public"],
deps = [
- "//tensorflow/contrib/tpu/proto:tpu_embedding_configuration_proto_cc",
- "//tensorflow/contrib/tpu/proto:tpu_embedding_output_layout_proto_cc",
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_configuration_proto_cc",
+ "//tensorflow/core/protobuf/tpu:tpu_embedding_output_layout_proto_cc",
],
)
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
index d98e0b7..d1df7e7 100644
--- a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.cc
@@ -135,7 +135,7 @@
}
namespace {
// Make a normal state variable specification. Please refer to
-// //third_party/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+// //tensorflow/core/protobuf/tpu/optimization_parameters.proto
// (StateVariableSpecification message) for instructions on how to set the
// padding_initial_value field.
StateVariableSpecification MakeStandardStateVariableSpecification(
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
index 81d5026..7a7833b 100644
--- a/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_optimization_parameters_utils.h
@@ -18,8 +18,8 @@
#include <string>
#include "absl/base/casts.h"
-#include "tensorflow/contrib/tpu/proto/optimization_parameters.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/tpu/optimization_parameters.pb.h"
namespace tensorflow {
namespace tpu {
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
index 8480ec4..e65abe3 100644
--- a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.cc
@@ -14,8 +14,8 @@
==============================================================================*/
#include "tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h"
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.pb.h"
namespace tensorflow {
namespace tpu {
diff --git a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
index c10fbee..1a04c7b 100644
--- a/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
+++ b/tensorflow/contrib/tpu/utils/tpu_embedding_output_layout_utils.h
@@ -16,9 +16,9 @@
#ifndef TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
#define TENSORFLOW_CONTRIB_TPU_UTILS_TPU_EMBEDDING_OUTPUT_LAYOUT_UTILS_H_
-#include "tensorflow/contrib/tpu/proto/tpu_embedding_configuration.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/tpu/tpu_embedding_configuration.pb.h"
namespace tensorflow {
namespace tpu {
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index f6427ae..5bc4c3b 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -264,9 +264,9 @@
py_test(
name = "training_test",
- size = "large",
+ size = "medium",
srcs = ["python/training/training_test.py"],
- shard_count = 3,
+ shard_count = 8,
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
diff --git a/tensorflow/contrib/training/python/training/training.py b/tensorflow/contrib/training/python/training/training.py
index fc6e38a..4ceb6e9 100644
--- a/tensorflow/contrib/training/python/training/training.py
+++ b/tensorflow/contrib/training/python/training/training.py
@@ -244,7 +244,6 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
@@ -354,11 +353,11 @@
raise ValueError('Requested multiple of `None` gradient.')
if isinstance(grad, ops.IndexedSlices):
- tmp = grad.values * constant_op.constant(
+ tmp = grad.values * ops.convert_to_tensor(
gradient_multipliers[key], dtype=grad.dtype)
grad = ops.IndexedSlices(tmp, grad.indices, grad.dense_shape)
else:
- grad *= constant_op.constant(
+ grad *= ops.convert_to_tensor(
gradient_multipliers[key], dtype=grad.dtype)
multiplied_grads_and_vars.append((grad, var))
return multiplied_grads_and_vars
@@ -433,7 +432,7 @@
else:
# Make sure that variables_to_train are in tf.trainable_variables()
for v in variables_to_train:
- assert v in tf_variables.trainable_variables()
+ assert v.trainable or v in tf_variables.trainable_variables()
assert variables_to_train
diff --git a/tensorflow/contrib/verbs/verbs_server_lib.cc b/tensorflow/contrib/verbs/verbs_server_lib.cc
index 19ef109..d07fd5a 100644
--- a/tensorflow/contrib/verbs/verbs_server_lib.cc
+++ b/tensorflow/contrib/verbs/verbs_server_lib.cc
@@ -81,7 +81,10 @@
Status VerbsServer::Init(ServiceInitFunction service_func,
RendezvousMgrCreationFunction rendezvous_mgr_func) {
std::call_once(reg_mem_visitors_call, []() { RdmaMgr::RegMemVisitors(); });
- Status s = GrpcServer::Init(service_func, rendezvous_mgr_func);
+ GrpcServerOptions opts;
+ opts.service_func = service_func;
+ opts.rendezvous_mgr_func = rendezvous_mgr_func;
+ Status s = GrpcServer::Init(opts);
{
mutex_lock l(mu_);
CHECK_EQ(verbs_state_, DISCONNECTED);
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 3d92a83..2f066d8 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -77,6 +77,7 @@
"//tensorflow:tensorflow.bzl",
"cc_header_only_library",
"if_android",
+ "if_emscripten",
"if_ios",
"if_linux_x86_64",
"if_mobile",
@@ -87,10 +88,12 @@
"tf_copts",
"tf_cuda_library",
"tf_features_nomodules_if_android",
+ "tf_features_nomodules_if_emscripten",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
+ "tf_opts_nortti_if_emscripten",
"transitive_hdrs",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
@@ -1143,6 +1146,13 @@
tf_gen_op_libs(
op_lib_names = [
+ "mkl_array_ops",
+ ],
+ deps = [":protos_all_cc"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"audio_ops",
],
deps = [":lib"],
@@ -1280,7 +1290,10 @@
":training_ops_op_lib",
":user_ops_op_lib",
":word2vec_ops",
- ] + if_mkl([":mkl_nn_ops_op_lib"]) + tf_additional_cloud_op_deps(),
+ ] + if_mkl([
+ ":mkl_array_ops_op_lib",
+ ":mkl_nn_ops_op_lib",
+ ]) + tf_additional_cloud_op_deps(),
alwayslink = 1,
)
@@ -1770,6 +1783,29 @@
],
)
+cc_library(
+ name = "emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime",
+ srcs = if_emscripten(["//tensorflow/core:mobile_srcs_no_runtime"]),
+ copts = ["-DSUPPORT_SELECTIVE_REGISTRATION"] + tf_opts_nortti_if_emscripten(),
+ defines = ["TENSORFLOW_LITE_PROTOS"],
+ linkopts = ["-lz"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":emscripten_proto_lib_no_rtti_lite_runtime",
+ ":mobile_additional_lib_deps",
+ ":stats_calculator_portable",
+ "//third_party/eigen3",
+ "@double_conversion//:double-conversion",
+ "@nsync//:nsync_cpp",
+ "@zlib_archive//:zlib",
+ ],
+ alwayslink = 1,
+)
+
# Native library support for iOS applications.
#
# bazel build --config=ios_x86_64 \
@@ -2274,6 +2310,8 @@
":lib_proto_parsing",
":abi",
":core_stringpiece",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
"//third_party/eigen3",
"//tensorflow/core/platform/default/build_config:platformlib",
"@snappy",
@@ -2653,7 +2691,6 @@
"example/**/*.cc",
"framework/**/*.cc",
"util/**/*.cc",
- ] + [
"graph/edgeset.cc",
"graph/graph.cc",
"graph/graph_def_builder.cc",
@@ -2898,6 +2935,7 @@
CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
+ "common_runtime/shared_counter.h",
"common_runtime/base_collective_executor.h",
"common_runtime/bfc_allocator.h",
"common_runtime/hierarchical_tree_broadcaster.h",
@@ -3680,6 +3718,20 @@
)
tf_cc_test(
+ name = "lib_strings_proto_serialization_test",
+ srcs = ["lib/strings/proto_serialization_test.cc"],
+ deps = [
+ ":lib",
+ ":lib_internal",
+ ":lib_test_internal",
+ ":protos_all_cc",
+ ":test",
+ ":test_main",
+ "@com_google_absl//absl/memory",
+ ],
+)
+
+tf_cc_test(
name = "lib_random_weighted_picker_test",
size = "medium",
srcs = ["lib/random/weighted_picker_test.cc"],
@@ -4484,7 +4536,7 @@
"//tensorflow/cc:scope",
"//tensorflow/core/kernels:cwise_op",
"//third_party/eigen3",
- ],
+ ] + if_mkl([":mkl_array_ops_op_lib"]),
)
tf_cc_test(
@@ -5037,6 +5089,39 @@
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
+load("//tensorflow:tensorflow.bzl", "tf_portable_proto_library")
+
+genrule(
+ name = "emscripten_proto_config_lite_runtime",
+ outs = ["emscripten_proto_config_lite_runtime.asciipb"],
+ cmd = tf_genrule_cmd_append_to_srcs("optimize_mode:LITE_RUNTIME"),
+ visibility = ["//visibility:private"],
+)
+
+# We are keeping the "android" version of tf_android_core_proto_headers. All it does is
+# normalize CORE_PROTO_SRCS to generate valid output file names.
+tf_portable_proto_library(
+ name = "emscripten_proto_lib_no_rtti_lite_runtime",
+ config = ":emscripten_proto_config_lite_runtime",
+ copts = tf_opts_nortti_if_emscripten(),
+ features = tf_features_nomodules_if_emscripten(),
+ header_outs = tf_android_core_proto_headers(CORE_PROTO_SRCS) + ["//google/protobuf/any.proto.h"],
+ link_full_protobuf = False,
+ prefix_dir = "emscripten_proto_no_rtti",
+ proto_deps = [
+ ":protos_all_cc",
+ "@protobuf_archive//:protobuf",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+# There is currently no need for a full proto version of emscripten tf lib lite.
+alias(
+ name = "emscripten_lib_lite_no_runtime",
+ actual = "//tensorflow/core:emscripten_tensorflow_lib_lite_nortti_lite_protos_no_runtime",
+ visibility = ["//visibility:public"],
+)
+
alias(
name = "android_srcs_no_runtime",
actual = ":mobile_srcs_no_runtime",
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
index c8152f5..22c3524 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeProtoV2.pbtxt
@@ -31,7 +31,8 @@
attr {
name: "field_names"
description: <<END
-List of strings containing proto field names.
+List of strings containing proto field names. An extension field can be decoded
+by using its full name, e.g. EXT_PACKAGE.EXT_FIELD_NAME.
END
}
attr {
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt
new file mode 100644
index 0000000..b845530
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalRebatchDataset.pbtxt
@@ -0,0 +1,23 @@
+op {
+ graph_op_name: "ExperimentalRebatchDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "num_workers"
+ description: <<END
+A scalar representing the number of workers to distribute this batch across. As
+a result of this transformation the current batch size would end up being
+divided by this parameter.
+END
+ }
+ summary: "Creates a dataset that changes the batch size."
+ description: <<END
+Creates a dataset that changes the batch size of the dataset to current batch
+size // num_workers.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index d33a36c..d5643c8 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the maximum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -29,5 +29,15 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
</div>
+
+For example:
+
+```
+c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+tf.segment_max(c, tf.constant([0, 0, 1]))
+# ==> [[4, 3, 3, 4],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index afdc39d..b03649a 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the mean along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -30,5 +30,15 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/SegmentMean.png" alt>
</div>
+
+For example:
+
+```
+c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+tf.segment_mean(c, tf.constant([0, 0, 1]))
+# ==> [[2.5, 2.5, 2.5, 2.5],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index 026b5b3..6796678 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the minimum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -29,5 +29,14 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
</div>
+
+For example:
+
+```
+c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+tf.segment_min(c, tf.constant([0, 0, 1]))
+# ==> [[1, 2, 2, 1],
+# [5, 6, 7, 8]]
+```
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index a168eed..10b368f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the product along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -29,5 +29,15 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/SegmentProd.png" alt>
</div>
+
+For example:
+
+```
+c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+tf.segment_prod(c, tf.constant([0, 0, 1]))
+# ==> [[4, 6, 6, 4],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index 876b860..487a6d1 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the sum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -29,5 +29,15 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/SegmentSum.png" alt>
</div>
+
+For example:
+
+```
+c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+# ==> [[5, 5, 5, 5],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ShardDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ShardDataset.pbtxt
new file mode 100644
index 0000000..cd537e0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ShardDataset.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "ShardDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "num_shards"
+ description: <<END
+An integer representing the number of shards operating in parallel.
+END
+ }
+ in_arg {
+ name: "index"
+ description: <<END
+An integer representing the current worker index.
+END
+ }
+ summary: "Creates a `Dataset` that includes only 1/`num_shards` of this dataset."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
index 138a636..0bbc078 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
@@ -21,9 +21,7 @@
}
summary: "Computes the mean along sparse segments of a tensor."
description: <<END
-Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-for an explanation of segments.
+See `tf.sparse.segment_sum` for usage examples.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
index b8073d8..65b2358 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
@@ -31,7 +31,7 @@
misisng, the `output` tensor at that position will be zeroed.
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
index 945bbdc..a28bd1a 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
@@ -23,8 +23,7 @@
description: <<END
N is the size of the segment being reduced.
-Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-for an explanation of segments.
+See `tf.sparse.segment_sum` for usage examples.
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
index ff328c8..8a5d2bb 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -33,7 +33,7 @@
misisng, the `output` tensor at that position will be zeroed.
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
index a68e146..d7494dc 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
@@ -22,7 +22,7 @@
summary: "Computes the sum along sparse segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
index aa5c1fc..039ca9a 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
@@ -31,7 +31,7 @@
misisng, the `output` tensor at that position will be zeroed.
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation)
for an explanation of segments.
For example:
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListConcatV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListConcatV2.pbtxt
new file mode 100644
index 0000000..9b2af2c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListConcatV2.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "TensorListConcatV2"
+ summary: "Concats all tensors in the list along the 0th dimension."
+ description: <<END
+Requires that all tensors have the same shape except the first dimension.
+
+input_handle: The input list.
+element_shape: The shape of the uninitialized elements in the list. If the first
+ dimension is not -1, it is assumed that all list elements have the same
+ leading dim.
+leading_dims: The list of leading dims of uninitialized list elements. Used if
+ the leading dim of input_handle.element_shape or the element_shape input arg
+ is not already set.
+tensor: The concated result.
+lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient.
+
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index ed4a2bd..f282b9f 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the maximum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
@@ -37,5 +37,15 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
</div>
+
+For example:
+
+``` python
+c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2)
+# ==> [[ 4, 3, 3, 4],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 7e139dd..0360cc0 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the minimum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
@@ -31,6 +31,15 @@
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::max()`.
+For example:
+
+``` python
+c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2)
+# ==> [[ 1, 2, 2, 1],
+# [5, 6, 7, 8]]
+```
+
If the given segment ID `i` is negative, then the corresponding value is
dropped, and will not be included in the result.
END
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index 9c8ea3b..67de473 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the product along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
@@ -28,6 +28,15 @@
\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
`j...` such that `segment_ids[j...] == i`.
+For example:
+
+``` python
+c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2)
+# ==> [[ 4, 6, 6, 4],
+# [5, 6, 7, 8]]
+```
+
If there is no entry for a given segment ID `i`, it outputs 1.
If the given segment ID `i` is negative, then the corresponding value is
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index 7e5d926..0813923 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -17,7 +17,7 @@
summary: "Computes the sum along segments of a tensor."
description: <<END
Read
-[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+[the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Computes a tensor such that
@@ -35,5 +35,13 @@
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
</div>
+
+``` python
+c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
+# ==> [[ 5, 5, 5, 5],
+# [5, 6, 7, 8]]
+```
+
END
}
diff --git a/tensorflow/core/api_def/excluded_ops.cc b/tensorflow/core/api_def/excluded_ops.cc
index 02026e9..65d2102 100644
--- a/tensorflow/core/api_def/excluded_ops.cc
+++ b/tensorflow/core/api_def/excluded_ops.cc
@@ -24,9 +24,9 @@
"GcsConfigureBlockCache", "GcsConfigureCredentials",
#ifdef INTEL_MKL
// QuantizedFusedOps for Intel CPU
- "QuantizedConv2DAndRequantize", "QuantizedConv2DWithBias",
- "QuantizedConv2DWithBiasAndRequantize", "QuantizedConv2DAndRelu",
- "QuantizedConv2DAndReluAndRequantize",
+ "QuantizedConcatV2", "QuantizedConv2DAndRequantize",
+ "QuantizedConv2DWithBias", "QuantizedConv2DWithBiasAndRequantize",
+ "QuantizedConv2DAndRelu", "QuantizedConv2DAndReluAndRequantize",
"QuantizedConv2DWithBiasAndRelu",
"QuantizedConv2DWithBiasAndReluAndRequantize",
"QuantizedConv2DWithBiasSumAndRelu",
diff --git a/tensorflow/core/api_def/python_api/api_def_DecodeWav.pbtxt b/tensorflow/core/api_def/python_api/api_def_DecodeWav.pbtxt
index d6fd469..28f4514 100644
--- a/tensorflow/core/api_def/python_api/api_def_DecodeWav.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_DecodeWav.pbtxt
@@ -1,4 +1,6 @@
op {
graph_op_name: "DecodeWav"
- visibility: HIDDEN
+ endpoint {
+ name: "audio.decode_wav"
+ }
}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatV2.pbtxt
new file mode 100644
index 0000000..237774a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListConcatV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 03c0e9c..8870a53 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -301,6 +301,8 @@
for (int32 instance : col_params.instance.impl_details.dependencies) {
auto find_iter = launched_.find(instance);
if (find_iter == launched_.end() || find_iter->second != 0) {
+ VLOG(1) << "Collective " << col_params.ToString()
+ << " blocked by instance " << instance;
return false;
}
}
@@ -313,6 +315,7 @@
while (!CheckDependencies(col_params)) {
launch_cv_.wait(l);
}
+ VLOG(1) << "Unblocking collective " << col_params.ToString();
}
void BaseCollectiveExecutor::Launched(const CollectiveParams& col_params) {
@@ -325,6 +328,8 @@
launched_[col_params.instance.instance_key] = num_devices;
}
if (--launched_[col_params.instance.instance_key] == 0) {
+ VLOG(1) << "Unblocking dependencies for collective instance "
+ << col_params.instance.instance_key;
launch_cv_.notify_all();
}
}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3843ea9..c7e535c 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -18,6 +18,7 @@
#include "tensorflow/core/common_runtime/bfc_allocator.h"
#include "tensorflow/core/common_runtime/allocator_retry.h"
+#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
@@ -152,6 +153,7 @@
c->allocation_id = -1;
c->prev = kInvalidChunkHandle;
c->next = kInvalidChunkHandle;
+ c->freed_count = 0;
region_manager_.set_handle(c->ptr, h);
@@ -180,29 +182,46 @@
free_chunks_list_ = h;
}
-void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes) {
+void* BFCAllocator::AllocateRawInternalWithRetry(
+ size_t unused_alignment, size_t num_bytes,
+ const AllocationAttributes& allocation_attr) {
// Fast path: Try once to allocate without getting the retry_helper_ involved
- void* r = AllocateRawInternal(unused_alignment, num_bytes, false);
+ uint64 freed_by_count = 0;
+ if (allocation_attr.freed_by_func != nullptr) {
+ freed_by_count = allocation_attr.freed_by_func();
+ }
+ void* r =
+ AllocateRawInternal(unused_alignment, num_bytes, false, freed_by_count);
if (r != nullptr) {
return r;
} else {
static const int64 kMaxMillisToWait = 10000; // 10 seconds
- return retry_helper_.AllocateRaw(
- [this](size_t a, size_t nb, bool v) {
- return AllocateRawInternal(a, nb, v);
+ r = retry_helper_.AllocateRaw(
+ [this, &allocation_attr](size_t a, size_t nb, bool v) {
+ uint64 freed_by_count = 0;
+ if (allocation_attr.freed_by_func != nullptr) {
+ freed_by_count = allocation_attr.freed_by_func();
+ }
+ return AllocateRawInternal(a, nb, v, freed_by_count);
},
kMaxMillisToWait, unused_alignment, num_bytes);
+ return r;
}
}
void* BFCAllocator::AllocateRaw(size_t unused_alignment, size_t num_bytes,
const AllocationAttributes& allocation_attr) {
+ VLOG(1) << "AllocateRaw " << Name() << " " << num_bytes;
if (allocation_attr.no_retry_on_failure) {
// Return immediately upon the first failure if this is for allocating an
// optional scratch space.
bool dump_log_on_failure = VLOG_IS_ON(2);
- void* result =
- AllocateRawInternal(unused_alignment, num_bytes, dump_log_on_failure);
+ uint64 freed_by_count = 0;
+ if (allocation_attr.freed_by_func != nullptr) {
+ freed_by_count = allocation_attr.freed_by_func();
+ }
+ void* result = AllocateRawInternal(unused_alignment, num_bytes,
+ dump_log_on_failure, freed_by_count);
if (result == nullptr) {
static std::atomic<int32> log_counter{0};
int32 counter_value = log_counter.load(std::memory_order_relaxed);
@@ -218,7 +237,8 @@
}
return result;
} else {
- return AllocateRaw(unused_alignment, num_bytes);
+ return AllocateRawInternalWithRetry(unused_alignment, num_bytes,
+ allocation_attr);
}
}
@@ -233,7 +253,8 @@
void* BFCAllocator::AllocateRawInternal(size_t unused_alignment,
size_t num_bytes,
- bool dump_log_on_failure) {
+ bool dump_log_on_failure,
+ uint64 freed_before) {
if (num_bytes == 0) {
LOG(ERROR) << "tried to allocate 0 bytes";
return nullptr;
@@ -247,14 +268,14 @@
BinNum bin_num = BinNumForSize(rounded_bytes);
mutex_lock l(lock_);
- void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
+ void* ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
return ptr;
}
// Try to extend
if (Extend(unused_alignment, rounded_bytes)) {
- ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes);
+ ptr = FindChunkPtr(bin_num, rounded_bytes, num_bytes, freed_before);
if (ptr != nullptr) {
return ptr;
}
@@ -274,7 +295,7 @@
}
void* BFCAllocator::FindChunkPtr(BinNum bin_num, size_t rounded_bytes,
- size_t num_bytes) {
+ size_t num_bytes, uint64 freed_before) {
// First identify the first bin that could satisfy rounded_bytes.
for (; bin_num < kNumBins; bin_num++) {
// Start searching from the first bin for the smallest chunk that fits
@@ -285,6 +306,9 @@
const BFCAllocator::ChunkHandle h = (*citer);
BFCAllocator::Chunk* chunk = ChunkFromHandle(h);
DCHECK(!chunk->in_use());
+ if (freed_before > 0 && freed_before < chunk->freed_count) {
+ continue;
+ }
if (chunk->size >= rounded_bytes) {
// We found an existing chunk that fits us that wasn't in use, so remove
// it from the free bin structure prior to using.
@@ -347,6 +371,9 @@
// The new chunk is not in use.
new_chunk->allocation_id = -1;
+ // It inherits the freed time.
+ new_chunk->freed_count = c->freed_count;
+
// Maintain the pointers.
// c <-> c_neighbor becomes
// c <-> new_chunk <-> c_neighbor
@@ -415,6 +442,9 @@
// Set the new size
c1->size += c2->size;
+ // Pick latest free time.
+ c1->freed_count = std::max(c1->freed_count, c2->freed_count);
+
DeleteChunk(h2);
}
@@ -460,6 +490,11 @@
// Mark the chunk as no longer in use.
c->allocation_id = -1;
+ // Optionally record the free time.
+ if (timing_counter_) {
+ c->freed_count = timing_counter_->next();
+ }
+
// Updates the stats.
stats_.bytes_in_use -= c->size;
@@ -630,7 +665,10 @@
in_use_by_size[c->size]++;
}
LOG(INFO) << (c->in_use() ? "Chunk" : "Free ") << " at " << c->ptr
- << " of size " << c->size;
+ << " of size " << c->size
+ << (timing_counter_
+ ? strings::StrCat(" freed_count ", c->freed_count)
+ : "");
h = c->next;
}
}
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 2d74bf2..261bacb 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -23,6 +23,7 @@
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
+#include "tensorflow/core/common_runtime/shared_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -50,9 +51,14 @@
~BFCAllocator() override;
string Name() override { return name_; }
- void* AllocateRaw(size_t alignment, size_t num_bytes) override;
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ return AllocateRaw(alignment, num_bytes, AllocationAttributes());
+ }
+
void* AllocateRaw(size_t alignment, size_t num_bytes,
const AllocationAttributes& allocation_attr) override;
+
void DeallocateRaw(void* ptr) override;
bool TracksAllocationSizes() override;
@@ -67,11 +73,19 @@
void ClearStats() override;
+ void SetTimingCounter(SharedCounter* sc) { timing_counter_ = sc; }
+
private:
struct Bin;
void* AllocateRawInternal(size_t alignment, size_t num_bytes,
- bool dump_log_on_failure);
+ bool dump_log_on_failure,
+ uint64 freed_before_count);
+
+ void* AllocateRawInternalWithRetry(
+ size_t alignment, size_t num_bytes,
+ const AllocationAttributes& allocation_attr);
+
void DeallocateRawInternal(void* ptr);
// A ChunkHandle is an index into the chunks_ vector in BFCAllocator
@@ -126,6 +140,9 @@
// What bin are we in?
BinNum bin_num = kInvalidBinNum;
+ // Optional count when this chunk was most recently made free.
+ uint64 freed_count = 0;
+
bool in_use() const { return allocation_id != -1; }
string DebugString(BFCAllocator* a,
@@ -314,8 +331,8 @@
// Returns a pointer to an underlying allocated chunk of size
// 'rounded_bytes'.
- void* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes)
- EXCLUSIVE_LOCKS_REQUIRED(lock_);
+ void* FindChunkPtr(BinNum bin_num, size_t rounded_bytes, size_t num_bytes,
+ uint64 freed_before) EXCLUSIVE_LOCKS_REQUIRED(lock_);
// Splits the chunk specified by 'h' into two chunks, one at least
// of size 'num_bytes'.
@@ -420,6 +437,7 @@
std::unique_ptr<SubAllocator> sub_allocator_;
string name_;
+ SharedCounter* timing_counter_ = nullptr;
// Structures mutable after construction
mutable mutex lock_;
diff --git a/tensorflow/core/common_runtime/buf_rendezvous_test.cc b/tensorflow/core/common_runtime/buf_rendezvous_test.cc
index 0e79823..7621787 100644
--- a/tensorflow/core/common_runtime/buf_rendezvous_test.cc
+++ b/tensorflow/core/common_runtime/buf_rendezvous_test.cc
@@ -109,7 +109,7 @@
TEST_F(BufRendezvousTest, ErrorDuplicatePut) {
bool prod_callback_called = false;
br_->ProvideBuf("key0", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
- [this, &prod_callback_called](const Status& s) {
+ [&prod_callback_called](const Status& s) {
prod_callback_called = true;
});
Status bad_status;
@@ -129,11 +129,11 @@
TEST_F(BufRendezvousTest, ErrorDeleteNonEmpty) {
Status cons_status;
- br_->ConsumeBuf(
- "key0", [this, &cons_status](const Status& s, BufRendezvous::Hook* h) {
- cons_status = s;
- EXPECT_EQ(h, nullptr);
- });
+ br_->ConsumeBuf("key0",
+ [&cons_status](const Status& s, BufRendezvous::Hook* h) {
+ cons_status = s;
+ EXPECT_EQ(h, nullptr);
+ });
EXPECT_TRUE(cons_status.ok());
br_.reset();
EXPECT_FALSE(cons_status.ok());
@@ -146,13 +146,13 @@
Status prod_status;
Notification prod_note;
Notification cons_note;
- br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
- const Status& s, BufRendezvous::Hook* h) {
+ br_->ConsumeBuf("key0", [&cons_note, &cons_status](const Status& s,
+ BufRendezvous::Hook* h) {
cons_status = s;
cons_note.Notify();
});
br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
- [this, &prod_note, &prod_status](const Status& s) {
+ [&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
});
@@ -175,13 +175,13 @@
Status prod_status;
Notification prod_note;
Notification cons_note;
- br_->ConsumeBuf("key0", [this, &cons_note, &cons_status](
- const Status& s, BufRendezvous::Hook* h) {
+ br_->ConsumeBuf("key0", [&cons_note, &cons_status](const Status& s,
+ BufRendezvous::Hook* h) {
cons_status = s;
cons_note.Notify();
});
br_->ProvideBuf("key1", fake_dev_ptr_, fake_dev_ctx_, &a_, aa_,
- [this, &prod_note, &prod_status](const Status& s) {
+ [&prod_note, &prod_status](const Status& s) {
prod_status = s;
prod_note.Notify();
});
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
index f3d86aa..3eef5ed 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
+++ b/tensorflow/core/common_runtime/collective_executor_mgr_test.cc
@@ -44,7 +44,7 @@
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> prl(
- new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ new CollectiveParamResolverLocal(cp, device_mgr_.get(), drl.get(),
task_name));
cme_.reset(new CollectiveExecutorMgr(cp, device_mgr_.get(), std::move(drl),
std::move(prl)));
@@ -73,11 +73,11 @@
EXPECT_EQ(CollectiveExecutor::kInvalidId, cme_->NextStepId(123));
Notification ss_note;
Status ss_status;
- cme_->RefreshStepIdSequenceAsync(
- 123, [this, &ss_status, &ss_note](const Status& s) {
- ss_status = s;
- ss_note.Notify();
- });
+ cme_->RefreshStepIdSequenceAsync(123,
+ [&ss_status, &ss_note](const Status& s) {
+ ss_status = s;
+ ss_note.Notify();
+ });
ss_note.WaitForNotification();
EXPECT_FALSE(ss_status.ok());
EXPECT_EQ(ss_status.error_message(),
@@ -87,7 +87,7 @@
GetStepSequenceRequest* req = nullptr;
GetStepSequenceResponse* resp = nullptr;
cme_->GetStepSequenceAsync(req, resp,
- [this, &gs_status, &gs_note](const Status& s) {
+ [&gs_status, &gs_note](const Status& s) {
gs_status = s;
gs_note.Notify();
});
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 8907f6d..5acba6e 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -38,9 +38,9 @@
}
CollectiveParamResolverLocal::CollectiveParamResolverLocal(
- const DeviceMgr* dev_mgr, DeviceResolverInterface* dev_resolver,
- const string& task_name)
- : nccl_(false), // (b/111897089): turn on NCCL collectives.
+ const ConfigProto& config, const DeviceMgr* dev_mgr,
+ DeviceResolverInterface* dev_resolver, const string& task_name)
+ : nccl_(config.experimental().collective_nccl()),
dev_mgr_(dev_mgr),
dev_resolver_(dev_resolver),
task_name_(task_name) {}
@@ -144,7 +144,6 @@
}
namespace {
-
struct DevRec {
string task;
string device;
@@ -361,7 +360,7 @@
for (int i = 0; i < perm.size(); ++i) {
perm[i] = i;
}
- std::sort(perm.begin(), perm.end(), [cp](const int& a, const int& b) {
+ std::sort(perm.begin(), perm.end(), [cp](int a, int b) {
return cp->instance.device_names[a] < cp->instance.device_names[b];
});
std::vector<string> new_devs;
@@ -585,7 +584,7 @@
void CollectiveParamResolverLocal::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const StatusCallback& done) {
- VLOG(1) << "CompleteParams " << device << " for " << cp << ": "
+ VLOG(1) << "CompleteParams local " << device << " for " << cp << ": "
<< cp->ToString();
CompleteGroupLocal(
device, cp,
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index fd408e4..08e2f33 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -23,6 +23,7 @@
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
class CompleteGroupRequest;
@@ -36,7 +37,8 @@
// group leader for param resolution in a multi-task context.
class CollectiveParamResolverLocal : public ParamResolverInterface {
public:
- CollectiveParamResolverLocal(const DeviceMgr* dev_mgr,
+ CollectiveParamResolverLocal(const ConfigProto& config,
+ const DeviceMgr* dev_mgr,
DeviceResolverInterface* dev_resolver,
const string& task_name);
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 94d889c..70eb9f8 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -41,8 +41,8 @@
TF_CHECK_OK(DeviceFactory::AddDevices(options, task_name, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
- prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
- task_name));
+ prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
+ drl_.get(), task_name));
}
void RunCompleteDefaultRanking(
@@ -175,7 +175,7 @@
Env::Default()->SchedClosure([this, i, cp, ¬e, &statuses]() {
prl_->CompleteParamsAsync(cp->instance.device_names[0], cp,
nullptr /*CancellationManager*/,
- [this, &statuses, ¬e, i](const Status& s) {
+ [&statuses, ¬e, i](const Status& s) {
statuses[i] = s;
note[i].Notify();
});
diff --git a/tensorflow/core/common_runtime/collective_rma_local_test.cc b/tensorflow/core/common_runtime/collective_rma_local_test.cc
index 4263f3a..2e9d8cd 100644
--- a/tensorflow/core/common_runtime/collective_rma_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_rma_local_test.cc
@@ -46,8 +46,8 @@
TF_CHECK_OK(DeviceFactory::AddDevices(options, kTaskName, &devices));
device_mgr_.reset(new DeviceMgr(std::move(devices)));
drl_.reset(new DeviceResolverLocal(device_mgr_.get()));
- prl_.reset(new CollectiveParamResolverLocal(device_mgr_.get(), drl_.get(),
- kTaskName));
+ prl_.reset(new CollectiveParamResolverLocal(cp, device_mgr_.get(),
+ drl_.get(), kTaskName));
rma_.reset(new CollectiveRemoteAccessLocal(device_mgr_.get(), drl_.get(),
kStepId));
}
@@ -70,7 +70,7 @@
"key_0", cpu0 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/,
- [this, &recv_note, &recv_status](const Status& s) {
+ [&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
});
@@ -85,7 +85,7 @@
rma_->PostToPeer(kTaskName + "/device:CPU:0", kTaskName, "key_0",
cpu0 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
- [this, &send_note, &send_status](const Status& s) {
+ [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
@@ -113,7 +113,7 @@
"key_0", cpu2 /*to_device*/, nullptr /*to_device_ctx*/,
attr /*to_alloc_attr*/, &sink_tensor, dev_locality,
0 /*stream_index*/,
- [this, &recv_note, &recv_status](const Status& s) {
+ [&recv_note, &recv_status](const Status& s) {
recv_status = s;
recv_note.Notify();
});
@@ -130,7 +130,7 @@
rma_->PostToPeer(kTaskName + "/device:CPU:2", kTaskName, "key_0",
cpu1 /*from_device*/, nullptr /*from_device_ctx*/,
attr /*to_alloc_attr*/, &source_tensor, dev_locality,
- [this, &send_note, &send_status](const Status& s) {
+ [&send_note, &send_status](const Status& s) {
send_status = s;
send_note.Notify();
});
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 5a0ef28..64119e8 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -127,16 +127,22 @@
// flag settings. Override this to return false for devices that don't allow
// such calls. Instead, these devices must use other mechanisms (such as
// num_deferred_ops) to ensure the device has finished processing necessary
- // work at session completion.
+ // work at session completion. In addition, for these devices, RefreshStatus
+ // must be called at session completion to retrieve execution result status.
//
- // Devices that override this function must also implement CurrentStatus.
+ // Devices that override this function must also implement RefreshStatus.
virtual bool AllowsSyncOnCompletion() const { return true; }
// This is used in conjunction with AllowsSyncOnCompletion to allow the
// executor to get execution result status at session completion.
- virtual Status CurrentStatus() {
+ //
+ // For supported devices, this call returns the underlying device stream's
+ // current status in a non-blocking way, without using blocking calls such as
+ // Stream::BlockHostUntilDone or Device::Sync. When applicable, the device
+ // status is also updated with the retrieved stream status.
+ virtual Status RefreshStatus() {
return errors::Unimplemented(
- "CurrentStatus is not supported on this device.");
+ "RefreshStatus is not supported on this device.");
}
// Optionally modify the device's GraphDef before execution.
diff --git a/tensorflow/core/common_runtime/device_resolver_local_test.cc b/tensorflow/core/common_runtime/device_resolver_local_test.cc
index 54f1119..62e82bc 100644
--- a/tensorflow/core/common_runtime/device_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/device_resolver_local_test.cc
@@ -56,7 +56,7 @@
Notification note;
Status status;
drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
- [this, ¬e, &status](const Status& s) {
+ [¬e, &status](const Status& s) {
status = s;
note.Notify();
});
@@ -74,7 +74,7 @@
Notification note;
Status status;
drl_->GetDeviceLocalitiesAsync(cp.instance, &localities,
- [this, ¬e, &status](const Status& s) {
+ [¬e, &status](const Status& s) {
status = s;
note.Notify();
});
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 80b62f2..40a1ffc 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -501,7 +501,8 @@
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
std::unique_ptr<ParamResolverInterface> cprl(
- new CollectiveParamResolverLocal(device_mgr_.get(), drl.get(),
+ new CollectiveParamResolverLocal(options_.config, device_mgr_.get(),
+ drl.get(),
"/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
options_.config, device_mgr_.get(), std::move(drl), std::move(cprl)));
@@ -1194,6 +1195,8 @@
if (options_.config.experimental()
.collective_deterministic_sequential_execution()) {
options.collective_order = GraphCollectiveOrder::kEdges;
+ } else if (options_.config.experimental().collective_nccl()) {
+ options.collective_order = GraphCollectiveOrder::kAttrs;
}
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 6f2680b..aef64da 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -191,7 +191,9 @@
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels:constant_op",
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 8068905..12e6483 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -20,9 +20,11 @@
#include "tensorflow/core/common_runtime/device_resolver_local.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
+#ifndef __ANDROID__
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"
+#endif
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/util/env_var.h"
@@ -81,7 +83,8 @@
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(local_device_mgr()));
std::unique_ptr<ParamResolverInterface> cprl(new CollectiveParamResolverLocal(
- local_device_mgr(), drl.get(), "/job:localhost/replica:0/task:0"));
+ opts.config, local_device_mgr(), drl.get(),
+ "/job:localhost/replica:0/task:0"));
collective_executor_mgr_.reset(new CollectiveExecutorMgr(
opts.config, local_device_mgr(), std::move(drl), std::move(cprl)));
}
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 53e9ba2..c6e8573 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -252,6 +252,16 @@
first->parsed_name().task == second->parsed_name().task;
}
+// Gets the CPU device on the task of device.
+Status CPUDeviceOnTask(EagerContext* ctx, tensorflow::Device* device,
+ tensorflow::Device** cpu_device) {
+ string cpu_device_name;
+ TF_RETURN_IF_ERROR(DeviceNameUtils::DeviceNameToCpuDeviceName(
+ device->name(), &cpu_device_name));
+
+ return ctx->FindDeviceByName(cpu_device_name, cpu_device);
+}
+
inline tensorflow::Fprint128 FingerprintCat128(const tensorflow::Fprint128& a,
const tensorflow::Fprint128& b) {
return {tensorflow::FingerprintCat64(a.low64, b.low64),
@@ -628,10 +638,16 @@
// explicitly copy, and instead depend on the copy to happen locally
// when the op is executed on the device.
!OnSameTask(ctx, op->Device(), input_device)) {
+ tensorflow::Device* remote_cpu_device;
+ TF_RETURN_IF_ERROR(
+ CPUDeviceOnTask(ctx, op->Device(), &remote_cpu_device));
// TODO(b/110044833): It's possible the same tensor gets copied to the
// remote device repeatedly.
+ // Always copy to the remote CPU so that the actual device can be
+ // correctly determined after the kernel is selected/instantiated, since
+ // the op might have its inputs on host memory.
TF_RETURN_IF_ERROR(MaybeCopyInputToExpectedDevice(
- op, op->Device(), i, op->Device(), /* run_metadata= */ nullptr,
+ op, op->Device(), i, remote_cpu_device, /* run_metadata= */ nullptr,
&(*op->MutableInputs())[i]));
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 09f60a7..41b4608 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -96,31 +96,24 @@
"Failed to parse config_proto attribute as tensorflow::ConfigProto "
"proto.");
}
- // We are going to execute the graph via function library runtime, and
- // because function execution semantics is slightly different from the
- // regular tensorlow graph, we need to make sure that Grappler respects it
- // when doing it's optimization passes (e.g. do not prune stateful and
- // dataset ops).
grappler::GrapplerItem::OptimizationOptions optimization_options;
- optimization_options.is_function_instantiation = true;
- // Keras graphs expected to be executed with regular graph execution
- // semantics (it's allowed to prune stateful and dataset ops).
- if (absl::StrContains(function_def->signature().name(), "keras_graph")) {
- optimization_options.is_function_instantiation = false;
- }
+ // Tensorflow 2.0 in eager mode with automatic control dependencies will
+ // prune all nodes that are not in the transitive fanin of the fetch nodes.
+ // However because the function will be executed via FunctionLibraryRuntime,
+ // and current function implementation does not prune stateful and dataset
+ // ops, we rely on Grappler to do the correct graph pruning.
+ optimization_options.allow_pruning_stateful_and_dataset_ops = true;
- // Wrapped function expects execution semantics to be the same as
- // `session.run`, so we should prune unreachable stateful and dataset ops.
- if (absl::StrContains(function_def->signature().name(),
- "wrapped_function")) {
- optimization_options.is_function_instantiation = false;
- }
+ // All the nested function calls will be executed and optimized via
+ // PartitionedCallOp, there is no need to optimize functions now.
+ optimization_options.optimize_function_library = false;
options.optimize_graph_fn = std::bind(
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3, std::placeholders::_4, config_proto,
- optimization_options, std::placeholders::_5);
+ function_def->signature().name(), optimization_options,
+ std::placeholders::_5);
}
#endif
options.graph_collector = graph_collector;
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index c4ea99f..027168d 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -151,7 +151,7 @@
Device* OutputResourceDevice(int idx) const override;
DataType input_type(int i) const override;
- const DataTypeVector& output_dtypes() const {
+ const DataTypeVector& output_dtypes() const override {
return kernel_->output_types();
}
int num_inputs() const override { return kernel_->num_inputs(); }
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index d068bbf..05f3e85 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2432,11 +2432,11 @@
// There are several potential race conditions below. To name a few:
// 1. Even if the device's status is OK at the precise moment when
- // num_deferred_ops_ reaches 0, it could go bad before device->CurrentStatus()
+ // num_deferred_ops_ reaches 0, it could go bad before device->RefreshStatus()
// is called below, caused by work enqueued onto the same device by other
// concurrent ExecutorState objects.
- // 2. Some implementations of Device::CurrentStatus, such as
- // XlaDevice::CurrentStatus, may be inherently racy because it releases the
+ // 2. Some implementations of Device::RefreshStatus, such as
+ // XlaDevice::RefreshStatus, may be inherently racy because it releases the
// device mutex after a stream pointer is acquired and before the stream is
// queried for status.
// 3. It's the same for some implementations of Device::Sync, such as
@@ -2462,7 +2462,7 @@
// these devices should have used num_deferred_ops correctly to ensure the
// device has finished all relevant work at this point.
if (!device->AllowsSyncOnCompletion()) {
- status.Update(device->CurrentStatus());
+ status.Update(device->RefreshStatus());
delete this;
runner([=]() { done_cb(status); });
return;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 010fdff..80d221a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -276,6 +276,28 @@
sync_every_op_(sync_every_op),
max_streams_(max_streams) {
GPUProcessState::singleton()->EnableGPUDevice();
+ pending_cap_ = options.config.gpu_options().experimental().pending_cap();
+ timestamped_allocator_ =
+ options.config.gpu_options().experimental().timestamped_allocator();
+ if (timestamped_allocator_ || pending_cap_ > 0) {
+ std::unique_ptr<SharedCounter> timing_counter;
+ if (timestamped_allocator_) {
+ // In this case the SharedCounter was already created and set in the
+ // associated Allocator, with ownership by GPUProcessState. Here we take
+ // over ownership of that SharedAllocator to transfer it to the
+ // GPUKernelTracker.
+ timing_counter =
+ GPUProcessState::singleton()->ReleaseGPUAllocatorCounter(tf_gpu_id);
+ DCHECK(timing_counter.get());
+ } else {
+ DCHECK_GT(pending_cap_, 0);
+ // In this case we need a SharedCounter to be owned by GPUKernelTracker
+ // but one was not created for use by the Allocator, so we create one.
+ timing_counter.reset(new SharedCounter);
+ }
+ kernel_tracker_.reset(
+ new GPUKernelTracker(Env::Default(), std::move(timing_counter)));
+ }
}
BaseGPUDevice::~BaseGPUDevice() {
@@ -508,6 +530,10 @@
if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
}
}
+ if (pending_cap_ > 0) {
+ DCHECK(kernel_tracker_);
+ kernel_tracker_->PauseWhilePendingExceeds(pending_cap_);
+ }
se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
op_kernel->Compute(context);
if (context->status().ok()) {
@@ -525,6 +551,14 @@
VLOG(1) << "GpuDevice::ComputeHelper scheduled "
<< ComputeOpKernelDebugString(*op_kernel, stream_id);
}
+ if (kernel_tracker_) {
+ GPUKernelTracker* tracker = kernel_tracker_.get();
+ DCHECK(tracker);
+ uint64 queued_count = tracker->RecordQueued();
+ em_->ThenExecute(stream, [op_kernel, tracker, queued_count]() {
+ tracker->RecordTerminated(queued_count);
+ });
+ }
} else {
if (vlog_1) {
VLOG(1) << "GpuDevice::ComputeHelper failed to schedule "
@@ -721,8 +755,8 @@
if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
return errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
- platform_gpu_id_str, "'. visible_device_list = ",
- visible_device_list);
+ platform_gpu_id_str,
+ "'. visible_device_list = ", visible_device_list);
}
if (platform_gpu_id < 0 ||
platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
@@ -957,15 +991,15 @@
for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
err = cudaSetDevice(platform_gpu_id.value());
if (err != cudaSuccess) {
- return errors::Internal("cudaSetDevice() on GPU:",
- platform_gpu_id.value(), " failed. Status: ",
- cudaGetErrorString(err));
+ return errors::Internal(
+ "cudaSetDevice() on GPU:", platform_gpu_id.value(),
+ " failed. Status: ", cudaGetErrorString(err));
}
err = cudaFree(nullptr);
if (err != cudaSuccess) {
return errors::Internal("CUDA runtime implicit initialization on GPU:",
- platform_gpu_id.value(), " failed. Status: ",
- cudaGetErrorString(err));
+ platform_gpu_id.value(),
+ " failed. Status: ", cudaGetErrorString(err));
}
}
// Reset to the original device.
@@ -1517,6 +1551,115 @@
return Status::OK();
}
+uint64 BaseGPUDevice::SafeAllocFrontier() {
+ if (timestamped_allocator_) {
+ return kernel_tracker_->LastTerminatedCount();
+ } else {
+ return 0;
+ }
+}
+
+int BaseGPUDevice::PendingKernels() {
+ if (kernel_tracker_) {
+ return kernel_tracker_->NumPending();
+ }
+ return 0;
+}
+
+uint64 GPUKernelTracker::RecordQueued() {
+ mutex_lock l(mu_);
+ uint64 queued_count = timing_counter_->next();
+ VLOG(2) << "RecordQueued queued_count=" << queued_count
+ << " first_available_=" << first_available_
+ << " last_completed_=" << last_completed_
+ << " num_pending_=" << num_pending_;
+ pending_kernels_[first_available_].queued_count = queued_count;
+ pending_kernels_[first_available_].terminated = false;
+ ++first_available_;
+ ++num_pending_;
+ if (first_available_ >= pending_kernels_.size()) {
+ first_available_ = 0;
+ }
+ if (first_available_ == last_completed_) {
+ // Ring buffer is full: double it. All of the same valid PendingKernel
+ // entries exist after the copy, they are just shifted to begin
+ // at index 0 in the new array.
+ std::vector<PendingKernel> new_buffer(pending_kernels_.size() * 2);
+ for (int i = 0; i < pending_kernels_.size(); ++i) {
+ int j = (i + last_completed_) % pending_kernels_.size();
+ new_buffer[i] = pending_kernels_[j];
+ }
+ last_completed_ = 0;
+ first_available_ = pending_kernels_.size();
+ pending_kernels_.swap(new_buffer);
+ VLOG(1) << "last_completed_=" << last_completed_
+ << " first_available_=" << first_available_
+ << " num_pending_=" << num_pending_;
+ }
+ DCHECK_NE(first_available_, last_completed_) << "exhausted pending_kernels";
+ return queued_count;
+}
+
+void GPUKernelTracker::RecordTerminated(uint64 queued_count) {
+ mutex_lock l(mu_);
+ VLOG(2) << "RecordTerminated queued_count=" << queued_count
+ << " first_available_=" << first_available_
+ << " last_completed_=" << last_completed_
+ << " num_pending_=" << num_pending_ << " LC="
+ << ((last_completed_ >= 0)
+ ? pending_kernels_[last_completed_].queued_count
+ : -1);
+ DCHECK_NE(first_available_, last_completed_);
+ DCHECK_GT(num_pending_, 0);
+ // Starting just past the last completed entry, find the entry with
+ // this queued_count and mark it done.
+ int index = (last_completed_ + 1) % pending_kernels_.size();
+ while (true) {
+ if (index == first_available_) {
+ // This should never happen.
+ LOG(FATAL) << "Failed to find " << queued_count // Crash OK
+ << " in queue";
+ }
+ if (pending_kernels_[index].queued_count == queued_count) {
+ pending_kernels_[index].terminated = true;
+ break;
+ }
+ index = (index + 1) % pending_kernels_.size();
+ }
+ // Next move last_completed_ forward past all completed kernels. In theory
+ // kernels should always complete in queued order so we should be able to
+ // advance the completed frontier to the last queued PendingKernel. In
+ // practice we occassionally see the termination callbacks arrive out of order
+ // probably because of thread scheduling. Eventually we may support out-of-
+ // order completion involving multple compute streams so here we follow a
+ // conservative approach and wait for every single callback to arrive before
+ // advancing the frontier.
+ while (true) {
+ int next_index = (last_completed_ + 1) % pending_kernels_.size();
+ if (next_index == first_available_) break;
+ if (pending_kernels_[next_index].terminated) {
+ last_completed_ = next_index;
+ } else {
+ break;
+ }
+ }
+ // Last decrease num_pending before maybe waking a waiter.
+ --num_pending_;
+ pending_decreased_.notify_one();
+}
+
+uint64 GPUKernelTracker::LastTerminatedCount() {
+ mutex_lock l(mu_);
+ if (last_completed_ < 0) {
+ // This is an edge case that can be encountered only at the beginning of
+ // execution. There's not yet a safe threshold count. We don't want to
+ // return 0 since that bypasses the count mechanism in BFCAllocator, so
+ // return the least non-zero value.
+ return 1;
+ }
+ return pending_kernels_[last_completed_].queued_count;
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index d002d02..33f0585 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -34,6 +34,7 @@
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
+#include "tensorflow/core/common_runtime/shared_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -46,6 +47,7 @@
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
+class GPUKernelTracker;
class BaseGPUDevice : public LocalDevice {
public:
@@ -114,6 +116,17 @@
return scoped_allocator_mgr_.get();
}
+ // The following two functions always return 0 unless one of the
+ // related experimental config options has been specified.
+
+ // If returned value is > 0 then GPU Memory chunks freed before this count
+ // are guaranteed not to be in use by any kernel pending on this device.
+ uint64 SafeAllocFrontier() override;
+
+ // Returns the number of kernels that have been queued for execution on
+ // the compute stream and are not yet known to have completed.
+ int PendingKernels();
+
protected:
Allocator* gpu_allocator_; // not owned
Allocator* cpu_allocator_; // not owned
@@ -141,6 +154,9 @@
const int32 max_streams_;
std::unique_ptr<EventMgr> em_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<GPUKernelTracker> kernel_tracker_;
+ int pending_cap_ = 0;
+ bool timestamped_allocator_ = false;
// Initialize scractch buffers used by Eigen.
Status InitScratchBuffers();
@@ -163,6 +179,75 @@
StatusCallback done);
};
+// A per-compute-stream utility that keeps track of kernels that have been
+// queued for execution but may not yet have terminated, and also the queued
+// time of the most recently terminated kernel.
+class GPUKernelTracker {
+ public:
+ explicit GPUKernelTracker(Env* env,
+ std::unique_ptr<SharedCounter> timing_counter)
+ : env_(env),
+ timing_counter_(std::move(timing_counter)),
+ pending_kernels_(64) {}
+
+ // Record that a GPU kernel has just been enqueued on the compute stream.
+ // Inserts a new timing counter value in a new PendingKernel record appended
+ // to the end of the ring buffer then returns that same count.
+ uint64 RecordQueued();
+
+ // Takes a count value returned by RecordQueued and finds the corresponding
+ // PendingKernel record in the ring buffer. Marks the kernel as completed and
+ // advances the completion frontier accordingly.
+ void RecordTerminated(uint64 at_count);
+
+ // Returns the largest timing count such that all kernels queued no
+ // later than that count are known to have terminated.
+ uint64 LastTerminatedCount();
+
+ // Returns the number of kernels enqueued that are not yet known to
+ // have terminated.
+ int NumPending() {
+ mutex_lock l(mu_);
+ return num_pending_;
+ }
+
+ // Yield current thread until number of pending kernels no longer
+ // exceeds the cap.
+ void PauseWhilePendingExceeds(int cap) {
+ mutex_lock l(mu_);
+ while (num_pending_ > cap) {
+ pending_decreased_.wait(l);
+ }
+ }
+
+ private:
+ Env* env_;
+ std::unique_ptr<SharedCounter> timing_counter_;
+
+ // Records when a kernel was queued for execution. Kernel launches are
+ // identified by a unique count value from a per-GPU device timing counter.
+ struct PendingKernel {
+ uint64 queued_count;
+ bool terminated;
+ PendingKernel(const PendingKernel& pk)
+ : queued_count(pk.queued_count), terminated(pk.terminated) {}
+ PendingKernel() : queued_count(0), terminated(false) {}
+ };
+ mutex mu_;
+ // Ring buffer of PendingKernel records.
+ std::vector<PendingKernel> pending_kernels_ GUARDED_BY(mu_);
+ // Next unused slot in pending_kernels_.
+ int first_available_ GUARDED_BY(mu_) = 0;
+ // Last completed PendingKernel such that all prior PendingKernels are
+ // also completed. With out-of-order completion there may be a mixture
+ // of completed and uncompleted entries between last_completed_ and
+ // first_available_, hence num_pending_ is not guaranteed equal to
+ // their differerence.
+ int last_completed_ GUARDED_BY(mu_) = -1;
+ int num_pending_ GUARDED_BY(mu_) = 0;
+ condition_variable pending_decreased_ GUARDED_BY(mu_);
+};
+
class BaseGPUDeviceFactory : public DeviceFactory {
public:
Status CreateDevices(const SessionOptions& options, const string& name_prefix,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index ae623b2..fba937a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -24,6 +24,7 @@
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -276,6 +277,71 @@
allocator->DeallocateRaw(ptr);
}
+class GPUKernelTrackerTest : public ::testing::Test {
+ protected:
+ void SetUp() {
+ std::unique_ptr<SharedCounter> counter(new SharedCounter);
+ timing_counter_ = counter.get();
+ kernel_tracker_.reset(
+ new GPUKernelTracker(Env::Default(), std::move(counter)));
+ }
+
+ std::unique_ptr<GPUKernelTracker> kernel_tracker_;
+ SharedCounter* timing_counter_ = nullptr;
+};
+
+TEST_F(GPUKernelTrackerTest, basic) {
+ EXPECT_EQ(0, kernel_tracker_->NumPending());
+ // 1 is the expected value when no kernels have yet terminated.
+ EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount());
+
+ std::deque<int64> queued_counts;
+ for (int i = 0; i < 32; ++i) {
+ queued_counts.push_back(kernel_tracker_->RecordQueued());
+ }
+ EXPECT_EQ(32, kernel_tracker_->NumPending());
+ EXPECT_EQ(1, kernel_tracker_->LastTerminatedCount());
+
+ // Mature the kernels in order until empty.
+ while (!queued_counts.empty()) {
+ int64 x = queued_counts.front();
+ queued_counts.pop_front();
+ kernel_tracker_->RecordTerminated(x);
+ EXPECT_EQ(queued_counts.size(), kernel_tracker_->NumPending());
+ EXPECT_EQ(x, kernel_tracker_->LastTerminatedCount());
+ }
+ EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount());
+
+ // Next inject so many kernel events that the ring buffer needs
+ // to grow a couple of times, while maturing a few in random order
+ // to introduce gaps between last_completed_ and first_available_.
+ int64 lower_bound = timing_counter_->get();
+ for (int i = 0; i < 1111; ++i) {
+ queued_counts.push_back(kernel_tracker_->RecordQueued());
+ int64 upper_bound = timing_counter_->get();
+ if (0 == (i % 16)) {
+ size_t index = (random::New64() % queued_counts.size());
+ kernel_tracker_->RecordTerminated(queued_counts[index]);
+ queued_counts.erase(queued_counts.begin() + index);
+ EXPECT_LE(lower_bound, kernel_tracker_->LastTerminatedCount());
+ EXPECT_GE(upper_bound, kernel_tracker_->LastTerminatedCount());
+ }
+ }
+
+ // Next mature the remaining kernels in order until empty.
+ while (!queued_counts.empty()) {
+ int64 x = queued_counts.front();
+ queued_counts.pop_front();
+ kernel_tracker_->RecordTerminated(x);
+ EXPECT_EQ(queued_counts.size(), kernel_tracker_->NumPending());
+ // There may be a gap here where we find a kernel that got terminated
+ // out of order, earlier, so the LastTerminatedCount can actually
+ // jump past x.
+ EXPECT_LE(x, kernel_tracker_->LastTerminatedCount());
+ }
+ EXPECT_EQ(timing_counter_->get(), kernel_tracker_->LastTerminatedCount());
+}
+
} // namespace tensorflow
#endif
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index 3c1c31a..6531d6d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -241,7 +241,9 @@
// events have recorded, and then retire them. Initial observations
// suggest that typical behavior in a TensorFlow program is to have
// 0-3 events pending most of the time, but there are occasionally
-// spikes of up to several hundred outstanding.
+// spikes of up to several hundred outstanding. (If GPUKernelTracker
+// is used to cap pending kernels there should never be more than
+// that many.)
//
// NOTE: If all events are on the same stream, no later event will
// complete before an earlier event, except possibly if the earlier
@@ -249,13 +251,10 @@
// looking past the first kPending event. However, if we're using
// multiple streams there may be some gain in looking deeper.
// As a compromise, PollEvent() calls that are triggered by the queueing
-// of a single event never look past the first kPending event. Calls
-// coming from the dedicated polling thread always sweep the full queue.
-//
-// Note that allowing the queue to grow very long could cause overall
-// GPU memory use to spike needlessly. An alternative strategy would
-// be to throttle new Op execution until the pending event queue
-// clears.
+// of a single event never look past the first kPending event. Consequently
+// those calls do an expected constant amount of work, unaffected by the
+// length of the pending queue. Calls coming from the dedicated
+// polling thread always sweep the full queue.
void EventMgr::PollEvents(bool is_dedicated_poller,
gtl::InlinedVector<InUse, 4>* to_free) {
VLOG(2) << "PollEvents free_events_ " << free_events_.size()
diff --git a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc
index 613633e..2375896 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_managed_allocator.cc
@@ -14,6 +14,7 @@
==============================================================================*/
#ifdef GOOGLE_CUDA
+#include "tensorflow/stream_executor/cuda/cuda_driver_wrapper.h"
#define EIGEN_USE_GPU
#endif
@@ -24,7 +25,11 @@
void* GpuManagedAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
void* ptr = nullptr;
#ifdef GOOGLE_CUDA
- CHECK_EQ(cudaMallocManaged(&ptr, num_bytes), cudaSuccess);
+ CUdeviceptr result = 0;
+ CHECK_EQ(tensorflow::wrap::cuMemAllocManaged(&result, num_bytes,
+ CU_MEM_ATTACH_GLOBAL),
+ CUDA_SUCCESS);
+ ptr = reinterpret_cast<void*>(result);
#endif
CHECK(!(reinterpret_cast<uintptr_t>(ptr) & (alignment - 1)));
return ptr;
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index 8167cfb..7804596 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -27,6 +27,7 @@
#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/common_runtime/pool_allocator.h"
+#include "tensorflow/core/common_runtime/shared_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
@@ -90,7 +91,7 @@
}
AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
- if (allocator_parts.allocator.get() == nullptr) {
+ if (allocator_parts.allocator == nullptr) {
// Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") {
LOG(ERROR) << "Invalid allocator type: " << allocator_type;
@@ -110,9 +111,15 @@
(options.per_process_gpu_memory_fraction() > 1.0 ||
options.experimental().use_unified_memory()),
gpu_visitors_[bus_id], {});
- Allocator* gpu_allocator =
+ GPUBFCAllocator* gpu_bfc_allocator =
new GPUBFCAllocator(sub_allocator, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
+ Allocator* gpu_allocator = gpu_bfc_allocator;
+ SharedCounter* timing_counter = nullptr;
+ if (options.experimental().timestamped_allocator()) {
+ timing_counter = new SharedCounter;
+ gpu_bfc_allocator->SetTimingCounter(timing_counter);
+ }
// If true, checks for memory overwrites by writing
// distinctive patterns on both ends of allocated memory.
@@ -137,7 +144,9 @@
recording_allocator = new internal::RecordingAllocator(
&process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
- allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+ allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator),
+ std::unique_ptr<SharedCounter>(timing_counter),
+ sub_allocator,
std::unique_ptr<Allocator>(recording_allocator)};
}
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
@@ -151,6 +160,23 @@
#endif // GOOGLE_CUDA
}
+std::unique_ptr<SharedCounter> GPUProcessState::ReleaseGPUAllocatorCounter(
+ TfGpuId tf_gpu_id) {
+ DCHECK(process_state_);
+#if GOOGLE_CUDA
+ GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
+ mutex_lock l(mu_);
+ if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
+ return nullptr;
+ }
+
+ AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+ return std::move(allocator_parts.counter);
+#else
+ return nullptr;
+#endif
+}
+
Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
CHECK(process_state_);
if (!HasGPUDevice() ||
@@ -224,6 +250,7 @@
allocator = new TrackingAllocator(allocator, true);
}
cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+ std::unique_ptr<SharedCounter>(nullptr),
sub_allocator,
std::unique_ptr<Allocator>(nullptr)});
AllocatorParts& allocator_parts = cuda_host_allocators_.back();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index df51c10..c7c9f3a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -23,6 +23,7 @@
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
#include "tensorflow/core/common_runtime/process_state.h"
+#include "tensorflow/core/common_runtime/shared_counter.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -33,6 +34,7 @@
class Allocator;
class PoolAllocator;
+class SharedCounter;
// Singleton that manages per-process state when GPUs are present.
class GPUProcessState {
@@ -108,6 +110,8 @@
// Returns bus_id for the given GPU id.
virtual int BusIdForGPU(TfGpuId tf_gpu_id);
+ std::unique_ptr<SharedCounter> ReleaseGPUAllocatorCounter(TfGpuId tf_gpu_id);
+
protected:
// GPUProcessState is a singleton that should not normally be deleted except
// at process shutdown.
@@ -132,6 +136,7 @@
struct AllocatorParts {
std::unique_ptr<Allocator> allocator;
+ std::unique_ptr<SharedCounter> counter;
SubAllocator* sub_allocator; // owned by allocator
std::unique_ptr<Allocator> recording_allocator;
};
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index f0656ff..12af4a8 100644
--- a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -616,7 +616,7 @@
auto* dev_info = device_->tensorflow_gpu_device_info();
CHECK(dev_info);
dev_info->default_context->CopyCPUTensorToDevice(
- &cpu_tensor, device_, &tensor_, [this, ¬ification](Status s) {
+ &cpu_tensor, device_, &tensor_, [¬ification](Status s) {
TF_CHECK_OK(s);
notification.Notify();
});
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index 8c0a3fb..c9827d1 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -20,6 +20,7 @@
#include <utility>
#include <vector>
+#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
@@ -32,6 +33,8 @@
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/port.h"
namespace tensorflow {
@@ -95,6 +98,27 @@
return filtered_devices;
}
+// Using absl::StrJoin with lambda does not work in tf-lite builds.
+std::vector<string> DevicesToString(const std::vector<Device*> devices) {
+ std::vector<string> v;
+ v.reserve(devices.size());
+ for (Device* d : devices) {
+ v.push_back(d->name());
+ }
+ return v;
+}
+
+// Using absl::StrJoin with lambda does not work in tf-lite builds.
+std::vector<string> DeviceTypeAndPriorityToString(
+ const PrioritizedDeviceTypeVector& devices) {
+ std::vector<string> v;
+ v.reserve(devices.size());
+ for (const std::pair<DeviceType, int32>& device_and_type : devices) {
+ v.push_back(DeviceTypeString(device_and_type.first));
+ }
+ return v;
+}
+
// Returns true if the node has no inputs and produces outputs
// that are consumed by a single node.
//
@@ -106,6 +130,16 @@
!IsRefType(node->output_type(0));
}
+// While Placer can override requested device on ops processing
+// resources, i.e. node that take (and potentially return) a resource,
+// it must not override requested device on ops generating a resource,
+// e.g. VarHandleOp, _Arg. Such ops are currently no-input, single resource/ref
+// output nodes.
+bool IsResourceGeneratorNode(const Node& node) {
+ return node.num_inputs() == 0 && node.num_outputs() == 1 &&
+ (IsRefType(node.output_type(0)) || node.output_type(0) == DT_RESOURCE);
+}
+
bool IsExemptFromResourceInputColocation(const Node* node) {
// Note: Partitioned function calls, which place and partition their
// function bodies, are exempt from this check: they forward resource and
@@ -115,6 +149,374 @@
return op_type == "PartitionedCall" || op_type == "StatefulPartitionedCall";
}
+bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
+ for (const auto& prioritized_device_type : device_types) {
+ if (prioritized_device_type.second != 0) return true;
+ }
+ return false;
+}
+
+bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
+ const PrioritizedDeviceTypeVector& b_types) {
+ if (a_types.size() != b_types.size()) {
+ return false;
+ }
+ for (int i = 0; i < a_types.size(); ++i) {
+ if (a_types[i].first != b_types[i].first) {
+ return false;
+ }
+ }
+ return true;
+}
+
+// Represents a node in the disjoint node forest and the
+// accumulated constraints on the device used by that node.
+class Member {
+ public:
+ Member() = default;
+
+ Status SetParentAndSupportedDevices(const Node& node,
+ const std::vector<DeviceType>& types) {
+ int id = node.id();
+ if (id < 0) {
+ return errors::Internal(
+ "Placer should not be creating a Member for node: ",
+ node.DebugString());
+ }
+ parent_ = id;
+ return SupportedDeviceTypesForNode(types, node.def(),
+ &supported_device_types_);
+ }
+
+ const DeviceNameUtils::ParsedName& requested_device_name() const {
+ return requested_device_name_;
+ }
+
+ Status SetAssignedDeviceName(const string& device_name) {
+ if (DeviceNameUtils::HasSomeDetails(requested_device_name_)) {
+ return errors::Internal(
+ "Setting assigned device name when there is a requested device set "
+ "is unsupported");
+ }
+ if (!DeviceNameUtils::ParseFullName(device_name, &assigned_device_name_)) {
+ return errors::Internal("Malformed assigned device '", device_name, "'");
+ }
+ // Set requested device to assigned_device to maintain the invariant that
+ // requested is a specialization of assigned.
+ requested_device_name_ = assigned_device_name_;
+ return Status::OK();
+ }
+
+ Status SetRequestedDeviceName(const Node& node) {
+ if (!DeviceNameUtils::ParseFullName(node.requested_device(),
+ &requested_device_name_)) {
+ return errors::InvalidArgument("Malformed device specification '",
+ node.requested_device(),
+ "' in node: ", node.DebugString());
+ }
+ if (DeviceNameUtils::HasSomeDetails(assigned_device_name_)) {
+ return errors::Internal(
+ "Setting requested device name when there is an assigned device set "
+ "is unsupported");
+ }
+ return Status::OK();
+ }
+
+ Status EnsureCompatibilityAcrossResourceEdge(
+ const Node& src, const Member& src_root,
+ const Node& dst, /*dst_root is this*/
+ bool log_device_placement) {
+ if (!DeviceNameUtils::AreCompatibleDevNames(src_root.assigned_device_name_,
+ assigned_device_name_)) {
+ return errors::InvalidArgument(
+ "Cannot place the graph because a reference or resource edge "
+ "connects colocation groups with incompatible assigned devices: ",
+ DeviceNameUtils::ParsedNameToString(src_root.assigned_device_name_),
+ " vs ", DeviceNameUtils::ParsedNameToString(assigned_device_name_));
+ }
+
+ if (DeviceNameUtils::AreCompatibleDevNames(src_root.requested_device_name_,
+ requested_device_name_)) {
+ return Status::OK();
+ }
+
+ // If we are here, assigned devices are compatible but requested ones are
+ // not. We will be overriding the requested device for destination node, but
+ // need to preserve the invariant that it will be a specialization of
+ // the assigned device.
+ if (log_device_placement) {
+ LOG(INFO) << "Ignoring device specification "
+ << DeviceNameUtils::ParsedNameToString(requested_device_name_)
+ << " for node '" << dst.name()
+ << "' because the input edge from '" << src.name()
+ << "' is a reference connection and already has a device "
+ "field set to "
+ << DeviceNameUtils::ParsedNameToString(
+ src_root.requested_device_name_);
+ }
+ requested_device_name_ = src_root.requested_device_name_;
+ DeviceNameUtils::EnsureSpecification(&requested_device_name_,
+ assigned_device_name_);
+ return Status::OK();
+ }
+
+ const PrioritizedDeviceTypeVector& supported_device_types() const {
+ return supported_device_types_;
+ }
+
+ static void Merge(std::vector<Member>* tree, int x_root, int y_root,
+ Member** new_root, Member** old_root) {
+ Member& x_root_member = (*tree)[x_root];
+ Member& y_root_member = (*tree)[y_root];
+
+ // Merge the sets by setting the parent pointer of the smaller tree's root
+ // node to point to the root of the larger tree. Together with path
+ // compression in ColocationGraph::FindRoot, this ensures that we do not
+ // experience pathological performance on graphs such as chains.
+ int new_root_id, old_root_id;
+ if (x_root_member.rank_ < y_root_member.rank_) {
+ // The tree rooted at x_root is shallower, so connect it to
+ // y_root. The rank of y_root is unchanged because its new
+ // child has strictly less rank.
+ x_root_member.parent_ = y_root;
+ new_root_id = y_root;
+ old_root_id = x_root;
+ } else if (x_root_member.rank_ > y_root_member.rank_) {
+ // The tree rooted at y_root is shallower, so connect it to
+ // x_root. The rank of x_root is unchanged because its new
+ // child has strictly less rank.
+ y_root_member.parent_ = x_root;
+ new_root_id = x_root;
+ old_root_id = y_root;
+ } else {
+ // Both trees have the same rank, so break the tie by choosing
+ // x_root as the new root.
+ y_root_member.parent_ = x_root;
+ // Increment the rank of the tree rooted at x_root, because it
+ // is now strictly deeper than before.
+ ++x_root_member.rank_;
+ new_root_id = x_root;
+ old_root_id = y_root;
+ }
+
+ *new_root = &(*tree)[new_root_id];
+ *old_root = &(*tree)[old_root_id];
+ }
+
+ // tree is non-const because we can change some `parent` pointers in some
+ // members for more efficient future lookups. The vector itself is not
+ // changed.
+ static int FindRoot(std::vector<Member>* tree, int node_id) {
+ Member& member = (*tree)[node_id];
+ if (member.parent_ == node_id) {
+ // member.parent is the root of this disjoint tree. Do nothing.
+ } else {
+ member.parent_ = FindRoot(tree, member.parent_);
+ }
+ // Now it is guaranteed that member.parent is the root of this disjoint
+ // tree.
+ return member.parent_;
+ }
+
+ Status MergeDeviceNames(const Member& other, bool allow_soft_placement) {
+ // Assuming the "requested is a specialization of assigned" invariant holds
+ // for this and `other`, it will hold after these two merges.
+ TF_RETURN_IF_ERROR(DeviceNameUtils::MergeDevNames(
+ &requested_device_name_, other.requested_device_name_,
+ allow_soft_placement));
+ return DeviceNameUtils::MergeDevNames(&assigned_device_name_,
+ other.assigned_device_name_,
+ allow_soft_placement);
+ }
+
+ // Updates this to contain the intersection of the device types in
+ // this and "other".
+ void MergeSupportedDevices(const Member& other) {
+ PrioritizedDeviceTypeVector temp = supported_device_types_;
+ supported_device_types_.clear();
+
+ // Generate intersection with priorities.
+ PrioritizedDeviceTypeVector target_intersection;
+ PrioritizedDeviceTypeVector other_intersection;
+ for (const auto& prioritized_device_type : temp) {
+ bool found = false;
+ for (const auto& other_prioritized_device_type :
+ other.supported_device_types_) {
+ if (prioritized_device_type.first ==
+ other_prioritized_device_type.first) {
+ found = true;
+ other_intersection.push_back(other_prioritized_device_type);
+ break;
+ }
+ }
+ if (found) {
+ target_intersection.push_back(prioritized_device_type);
+ }
+ }
+
+ // Sort the devices by priority order.
+ auto device_sort = [](const std::pair<DeviceType, int32>& a,
+ const std::pair<DeviceType, int32>& b) {
+ // First look at set priorities.
+ if (a.second != b.second) {
+ return a.second > b.second;
+ }
+ // Then fallback to default priorities.
+ auto a_priority = DeviceSet::DeviceTypeOrder(a.first);
+ auto b_priority = DeviceSet::DeviceTypeOrder(b.first);
+ if (a_priority != b_priority) {
+ return a_priority > b_priority;
+ }
+ // Finally just look at the Device type strings.
+ return a.first.type_string() < b.first.type_string();
+ };
+
+ std::sort(target_intersection.begin(), target_intersection.end(),
+ device_sort);
+ std::sort(other_intersection.begin(), other_intersection.end(),
+ device_sort);
+
+ bool is_target_prioritized = HasPriorities(target_intersection);
+ bool is_other_prioritized = HasPriorities(other_intersection);
+ // If neither are prioritized then we just return the original i.e. target
+ // prioritization.
+ if (!is_target_prioritized && !is_other_prioritized) {
+ supported_device_types_ = target_intersection;
+ }
+ // If only one is prioritized, then we respect priorities of that in the
+ // intersection.
+ if (is_target_prioritized && !is_other_prioritized) {
+ supported_device_types_ = target_intersection;
+ }
+ if (!is_target_prioritized && is_other_prioritized) {
+ supported_device_types_ = other_intersection;
+ }
+ // If both have priorities and agree then we go with that. If the
+ // prioritization order is different, then we just fallback to the default
+ // i.e. what the DeviceTypeOrder suggests. In that case, we also set the
+ // merged priorities to 0, so that downstream merges work correctly as well.
+ if (is_target_prioritized && is_other_prioritized) {
+ bool priorities_agree =
+ ArePrioritiesSame(target_intersection, other_intersection);
+ if (priorities_agree) {
+ supported_device_types_ = target_intersection;
+ } else {
+ for (const auto& prioritized_device : target_intersection) {
+ supported_device_types_.push_back(
+ std::make_pair(prioritized_device.first, 0));
+ }
+ std::sort(supported_device_types_.begin(),
+ supported_device_types_.end(), device_sort);
+ }
+ }
+ }
+
+ Status AssignDevice(const Node& node, bool allow_soft_placement) {
+ if (node.assigned_device_name_index() == assigned_device_name_index_) {
+ return Status::OK();
+ }
+
+ DeviceNameUtils::ParsedName parsed;
+ DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
+ Status s = DeviceNameUtils::MergeDevNames(&assigned_device_name_, parsed,
+ allow_soft_placement);
+ if (!s.ok()) {
+ return errors::Internal(
+ "Constraining by assigned device should not cause an error. Original "
+ "root's assigned device name: ",
+ DeviceNameUtils::ParsedNameToString(assigned_device_name_),
+ " node's assigned device name \"", node.assigned_device_name(),
+ ". Error: ", s.error_message());
+ }
+ s = DeviceNameUtils::MergeDevNames(&requested_device_name_, parsed,
+ allow_soft_placement);
+ if (!s.ok()) {
+ return errors::Internal(
+ "Constraining by assigned device should not cause an error. Original "
+ "root's requested device name: \"",
+ DeviceNameUtils::ParsedNameToString(requested_device_name_),
+ "\", node's assigned device name \"", node.assigned_device_name(),
+ "\". Error: ", s.error_message());
+ }
+
+ assigned_device_name_index_ = node.assigned_device_name_index();
+ // Clear cached possible_devices, if any.
+ possible_devices_.clear();
+ return Status::OK();
+ }
+
+ void set_possible_devices(std::vector<Device*>&& devices) {
+ possible_devices_ = devices;
+ }
+ const std::vector<Device*>& possible_devices() { return possible_devices_; }
+
+ string DebugString() {
+ return absl::StrCat(
+ "Member(assigned_device_name_index_=", assigned_device_name_index_,
+ " requested_device_name_=",
+ DeviceNameUtils::ParsedNameToString(requested_device_name_),
+ " assigned_device_name_=",
+ DeviceNameUtils::ParsedNameToString(assigned_device_name_),
+ " supported_device_types_=[",
+ absl::StrJoin(DeviceTypeAndPriorityToString(supported_device_types_),
+ ", "),
+ "] possible_devices_=[",
+ absl::StrJoin(DevicesToString(possible_devices_), ", "), "]");
+ }
+
+ private:
+ // The id of the node that is the parent of this one, or its own
+ // id if it is a root. parent <= 0 indicates that this member is invalid.
+ int parent_ = -1;
+
+ // A proxy for the depth of the tree that is used to prefer
+ // connecting smaller trees to larger trees when merging disjoint
+ // sets.
+ int rank_ = 0;
+
+ // Once colocation groups have been formed, the Placer starts actually
+ // choosing devices. All nodes in a group must be assigned to the same
+ // device. Once we assigned the first device to some node in this group,
+ // we set assigned_device_name_index to this device name's index in the
+ // graph.
+ // The `*_device_name_` fields will contain the parsed name of this device
+ // and `possible_devices`, if computed, will contain just this device.
+ // `assigned_device_name_index` is an optimization to avoid parsing and
+ // comparing device names. The value of -1 signals that a single device
+ // has not been chosen yet.
+ int assigned_device_name_index_ = -1;
+
+ // The merged form of the device requested for this node, with those of all of
+ // its children. requested_device_name_ is always kept a specialization (i.e.
+ // DeviceNameUtils::IsSpecialization) of assigned_device_name_. When no device
+ // is requested, this field is set to assigned_device_name_. As a
+ // specialization of assigned_device_name_, requested_device_name_ represents
+ // the most specific form of all assigned and requested devices of this node
+ // and its children, if this node is a root. requested_device_name_ is used
+ // to finally select devices for nodes. We can override requested devices due
+ // to resource colocation constraints but not assigned devices (unless soft
+ // placement is on).
+ DeviceNameUtils::ParsedName requested_device_name_;
+
+ // The merged form of the device assigned for this node, with
+ // those of all of its children.
+ // This field is used to raise errors due to unsatisfiable constraints.
+ // Can be a partial specification.
+ // INVARIANT: requested_device_name_ is always a
+ // DeviceNameUtils::IsSpecialization of assigned_device_name_.
+ DeviceNameUtils::ParsedName assigned_device_name_;
+
+ // The intersection of all device types supported by this node,
+ // and those of all of its children, in priority order
+ // of the preferred device.
+ PrioritizedDeviceTypeVector supported_device_types_;
+
+ // If this node is a root, stores a list of Devices to which this node
+ // and all of its children have been assigned, or nullptr if this
+ // has not yet been computed.
+ std::vector<Device*> possible_devices_;
+}; // namespace
+
// This class maintains the connected components of a colocation
// constraint graph, and uses this information to assign a satisfying
// device placement to the nodes of the graph.
@@ -227,34 +629,9 @@
int dst_root_id = FindRoot(dst->id());
auto& src_root = members_[src_root_id];
auto& dst_root = members_[dst_root_id];
- // If both the source node and this node have partially
- // specified a device, then 'dst's device should be
- // cleared: the reference edge forces 'node' to be on the
- // same device as the source node.
- const auto& source_parsed_name = src_root.device_name;
- const auto& dest_parsed_name = dst_root.device_name;
- if (DeviceNameUtils::HasSomeDetails(source_parsed_name) &&
- DeviceNameUtils::HasSomeDetails(dest_parsed_name)) {
- // Ignore a specified device for 'dst' if the two names were
- // incompatible.
- if (!DeviceNameUtils::AreCompatibleDevNames(source_parsed_name,
- dest_parsed_name)) {
- TF_RETURN_IF_ERROR(VerifyResourceAndRefInputsCanBeColocated(
- dst, src, source_parsed_name));
- if (log_device_placement_) {
- LOG(INFO) << "Ignoring device specification "
- << DeviceNameUtils::ParsedNameToString(dest_parsed_name)
- << " for node '" << dst->name()
- << "' because the input edge from '" << src->name()
- << "' is a reference connection and already has a device "
- "field set to "
- << DeviceNameUtils::ParsedNameToString(source_parsed_name);
- }
- // Make 'dst' colocated with the source
- dst_root.device_name = source_parsed_name;
- }
- }
+ TF_RETURN_IF_ERROR(dst_root.EnsureCompatibilityAcrossResourceEdge(
+ *src, src_root, *dst, log_device_placement_));
Status status = ColocateNodes(*src, src_root_id, *dst, dst_root_id);
if (!status.ok()) {
return AttachDef(
@@ -337,50 +714,18 @@
DCHECK_EQ(x_root, FindRoot(x.id()));
DCHECK_EQ(y_root, FindRoot(y.id()));
- Member& x_root_member = members_[x_root];
- Member& y_root_member = members_[y_root];
-
- // Merge the sets by setting the parent pointer of the smaller tree's root
- // node to point to the root of the larger tree. Together with path
- // compression in ColocationGraph::FindRoot, this ensures that we do not
- // experience pathological performance on graphs such as chains.
- int new_root, old_root;
- if (x_root_member.rank < y_root_member.rank) {
- // The tree rooted at x_root is shallower, so connect it to
- // y_root. The rank of y_root is unchanged because its new
- // child has strictly less rank.
- x_root_member.parent = y_root;
- new_root = y_root;
- old_root = x_root;
- } else if (x_root_member.rank > y_root_member.rank) {
- // The tree rooted at y_root is shallower, so connect it to
- // x_root. The rank of x_root is unchanged because its new
- // child has strictly less rank.
- y_root_member.parent = x_root;
- new_root = x_root;
- old_root = y_root;
- } else {
- // Both trees have the same rank, so break the tie by choosing
- // x_root as the new root.
- y_root_member.parent = x_root;
- // Increment the rank of the tree rooted at x_root, because it
- // is now strictly deeper than before.
- ++x_root_member.rank;
- new_root = x_root;
- old_root = y_root;
- }
-
- Member& new_root_member = members_[new_root];
- Member& old_root_member = members_[old_root];
+ Member* new_root_member;
+ Member* old_root_member;
+ Member::Merge(&members_, x_root, y_root, &new_root_member,
+ &old_root_member);
// Merge the partial device specifications, and ensure that they are
// compatible. NULL options_ is treated as allowing soft placement.
// TODO(mrry): Consider enriching the error message by pointing
// out which nodes have the explicit partial device
// specifications that caused this conflict.
- Status s = DeviceNameUtils::MergeDevNames(&new_root_member.device_name,
- old_root_member.device_name,
- allow_soft_placement_);
+ Status s = new_root_member->MergeDeviceNames(*old_root_member,
+ allow_soft_placement_);
if (!s.ok()) {
return errors::InvalidArgument(
"Cannot colocate nodes ",
@@ -393,9 +738,8 @@
// type, by computing the intersection of
// new_root_member.supported_device_types and
// old_root_member.supported_device_types.
- MergeSupportedDevices(&new_root_member.supported_device_types,
- old_root_member.supported_device_types);
- if (new_root_member.supported_device_types.empty()) {
+ new_root_member->MergeSupportedDevices(*old_root_member);
+ if (new_root_member->supported_device_types().empty()) {
return errors::InvalidArgument(
"Cannot colocate nodes ",
errors::FormatColocationNodeForError(x.name()), " and ",
@@ -422,28 +766,7 @@
}
int root = FindRoot(node.id());
Member& root_member = members_[root];
- if (node.assigned_device_name_index() ==
- root_member.assigned_device_name_index) {
- return Status::OK();
- }
- DeviceNameUtils::ParsedName parsed;
- DeviceNameUtils::ParseFullName(node.assigned_device_name(), &parsed);
- Status s = DeviceNameUtils::MergeDevNames(&root_member.device_name, parsed,
- allow_soft_placement_);
- if (!s.ok()) {
- return errors::Internal(
- "Constraining by assigned device should not cause an error. Original "
- "root device name: ",
- DeviceNameUtils::ParsedNameToString(root_member.device_name),
- " assigned device name \"", node.assigned_device_name(),
- ". Error: ", s.error_message());
- }
-
- root_member.assigned_device_name_index = node.assigned_device_name_index();
- // Clear cached possible_devices, if any.
- root_member.possible_devices.clear();
-
- return Status::OK();
+ return root_member.AssignDevice(node, allow_soft_placement_);
}
// For the given node, subject to the constraints previously given
@@ -454,11 +777,11 @@
// The caller must not use the returned pointer after there is any possibility
// that the members_[i].possible_devices field has been modified.
Status GetDevicesForNode(Node* node,
- std::vector<Device*>** possible_devices) {
+ const std::vector<Device*>** possible_devices) {
*possible_devices = nullptr;
const int node_root = FindRoot(node->id());
- if (!members_[node_root].possible_devices.empty()) {
- *possible_devices = &members_[node_root].possible_devices;
+ if (!members_[node_root].possible_devices().empty()) {
+ *possible_devices = &members_[node_root].possible_devices();
return Status::OK();
}
@@ -469,18 +792,19 @@
// "devices" will contain the set of feasible placements for the
// colocated node set containing 'node'.
std::vector<Device*> devices;
- if (DeviceNameUtils::HasSomeDetails(members_[node_root].device_name)) {
+ if (DeviceNameUtils::HasSomeDetails(
+ members_[node_root].requested_device_name())) {
// The root node has a (possibly partial) device
// specification, so enumerate the physical devices that
// conform to it.
- device_set_->FindMatchingDevices(members_[node_root].device_name,
- &devices);
+ device_set_->FindMatchingDevices(
+ members_[node_root].requested_device_name(), &devices);
if (!devices.empty()) {
// Filter devices into those that are compatible with the root
// node (and its children).
devices = FilterSupportedDevices(
- devices, members_[node_root].supported_device_types,
+ devices, members_[node_root].supported_device_types(),
default_device_);
}
@@ -489,14 +813,14 @@
// The soft_device_name is the same as the node's device name
// without specifying the device type or ID.
DeviceNameUtils::ParsedName soft_device_name =
- members_[node_root].device_name;
+ members_[node_root].requested_device_name();
soft_device_name.type.clear();
soft_device_name.has_type = false;
soft_device_name.has_id = false;
device_set_->FindMatchingDevices(soft_device_name, &devices);
if (!devices.empty()) {
devices = FilterSupportedDevices(
- devices, members_[node_root].supported_device_types,
+ devices, members_[node_root].supported_device_types(),
default_device_);
}
}
@@ -510,7 +834,8 @@
DeviceNameUtils::ParsedName specified_device_name;
if (DeviceNameUtils::ParseFullName(node->requested_device(),
&specified_device_name) &&
- specified_device_name == members_[node_root].device_name) {
+ specified_device_name ==
+ members_[node_root].requested_device_name()) {
// The specified device and merged set device match, and
// will appear in the GraphDef (for debugging), so just
// print the specified device.
@@ -562,7 +887,7 @@
" was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
- members_[node_root].device_name),
+ members_[node_root].requested_device_name()),
"'", debug_info);
}
}
@@ -573,7 +898,7 @@
return errors::Internal("No devices are registered");
}
devices = FilterSupportedDevices(
- device_set_->devices(), members_[node_root].supported_device_types,
+ device_set_->devices(), members_[node_root].supported_device_types(),
default_device_);
if (devices.empty()) {
@@ -585,16 +910,13 @@
}
// Cache the result of the possible devices for this node group.
- members_[node_root].possible_devices = std::move(devices);
- *possible_devices = &members_[node_root].possible_devices;
+ members_[node_root].set_possible_devices(std::move(devices));
+ *possible_devices = &members_[node_root].possible_devices();
return Status::OK();
}
Status InitializeMembers() {
- for (Node* node : graph_->nodes()) {
- if (!node->IsOp()) {
- continue;
- }
+ for (Node* node : graph_->op_nodes()) {
Status status = InitializeMember(*node, &members_[node->id()]);
if (!status.ok()) {
return AttachDef(status, *node);
@@ -603,43 +925,21 @@
return Status::OK();
}
- // Represents a node in the disjoint node set forest, and the
- // accumulated constraints on the device used by that node.
- struct Member {
- Member() = default;
- // The id of the node that is the parent of this one, or its own
- // id if it is a root. parent <= 0 indicates that this member is invalid.
- int parent = -1;
-
- // A proxy for the depth of the tree that is used to prefer
- // connecting smaller trees to larger trees when merging disjoint
- // sets.
- int rank = 0;
-
- // The intersection of all device types supported by this node,
- // and those of all of its children, in priority order
- // of the preferred device.
- PrioritizedDeviceTypeVector supported_device_types;
-
- // The merged form of the device requested for this node, with
- // those of all of its children.
- DeviceNameUtils::ParsedName device_name;
-
- // Once colocation groups have been formed and we assigned at least
- // one node in this group to a device, assigned_device_name_index will
- // contain this device name's index in the graph. The `device_name` will
- // contain the parsed name of this device and `possible_devices`, if
- // computed, will contain just this device.
- // `assigned_device_name_index` is an optimization to avoid parsing and
- // comparing device names. The value of -1 signals that a single device
- // has not been chosen yet.
- int assigned_device_name_index = -1;
-
- // If this node is a root, stores a list of Devices to which this node
- // and all of its children have been assigned, or nullptr if this
- // has not yet been computed.
- std::vector<Device*> possible_devices;
- };
+ string DebugString() {
+ std::unordered_set<int> roots;
+ std::vector<string> root_strings;
+ for (const Node* node : graph_->nodes()) {
+ if (!node->IsOp()) {
+ continue;
+ }
+ int node_root = FindRoot(node->id());
+ if (roots.count(node_root) == 0) {
+ root_strings.push_back(DebugInfo(node_root));
+ roots.insert(node_root);
+ }
+ }
+ return absl::StrJoin(root_strings, "\n");
+ }
// Returns debugging info for the node referred to by 'node_root'.
string DebugInfo(const int node_root) {
@@ -667,7 +967,7 @@
colocation_nodes.push_back(node);
const string& op_type = node->type_string();
string devices_registered;
- for (const auto& device_type : members_[id].supported_device_types) {
+ for (const auto& device_type : members_[id].supported_device_types()) {
strings::StrAppend(&devices_registered,
DeviceTypeString(device_type.first), " ");
}
@@ -686,53 +986,52 @@
}
strings::StrAppend(&text, "\n");
- if (num_nodes_found <= 1) {
+ if (num_nodes_found <= 0) {
text.clear();
}
return text;
}
+ Status InitializeMemberWithAssignedDevice(const string& assigned_device_name,
+ const string& node_type,
+ bool must_be_full_name,
+ Member* member) {
+ // This node has already been assigned to a device, so we
+ // respect this placement, after sanity-checking it.
+ // NOTE: Since any assignment must have been performed by
+ // the TensorFlow runtime, we consider errors in this branch to
+ // be INTERNAL.
+ TF_RETURN_IF_ERROR(member->SetAssignedDeviceName(assigned_device_name));
+ if (!must_be_full_name) {
+ return Status::OK();
+ }
+ // Since assigned device must be a full specification, do extra checks.
+ const Device* assigned_device =
+ device_set_->FindDeviceByName(assigned_device_name);
+ if (assigned_device == nullptr) {
+ return errors::Internal("Assigned device '", assigned_device_name,
+ "' does not match any device");
+ }
+
+ for (const auto& d : member->supported_device_types()) {
+ if (DeviceType(assigned_device->attributes().device_type()) == d.first) {
+ return Status::OK();
+ }
+ }
+
+ return errors::Internal("Assigned device '", assigned_device_name,
+ "' does not have registered OpKernel support "
+ "for ",
+ node_type);
+ }
+
Status InitializeMember(const Node& node, Member* member) {
- const int id = node.id();
- DCHECK_GE(id, 0);
- member->parent = id;
- TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
- device_types_, node.def(), &member->supported_device_types));
+ TF_RETURN_IF_ERROR(
+ member->SetParentAndSupportedDevices(node, device_types_));
if (node.has_assigned_device_name()) {
- // This node has already been assigned to a device, so we
- // respect this placement, after sanity-checking it. The
- // device_name and supported_device_types for this node reflect
- // the assigned device, so any nodes colocated with this node
- // will be assigned to the same device (assuming this is
- // possible).
- // NOTE: Since any assignment must have been performed by
- // the TensorFlow runtime, we consider errors in this branch to
- // be INTERNAL.
- const string& assigned_device_name = node.assigned_device_name();
- if (!DeviceNameUtils::ParseFullName(assigned_device_name,
- &member->device_name)) {
- return errors::Internal("Malformed assigned device '",
- assigned_device_name, "'");
- }
- const Device* assigned_device =
- device_set_->FindDeviceByName(assigned_device_name);
- if (assigned_device == nullptr) {
- return errors::Internal("Assigned device '", assigned_device_name,
- "' does not match any device");
- }
-
- for (const auto& d : member->supported_device_types) {
- if (DeviceType(assigned_device->attributes().device_type()) ==
- d.first) {
- return Status::OK();
- }
- }
-
- return errors::Internal("Assigned device '", assigned_device_name,
- "' does not have registered OpKernel support "
- "for ",
- node.type_string());
+ TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
+ node.assigned_device_name(), node.type_string(), true, member));
} else {
// This node has not yet been assigned to a device, so we
// calculate any constraints due to the set of registered
@@ -740,7 +1039,7 @@
// in the NodeDef.
// If no kernels are registered for this op type, fail with an error.
- if (member->supported_device_types.empty()) {
+ if (member->supported_device_types().empty()) {
std::set<string> registered_device_types;
for (Device* d : device_set_->devices()) {
registered_device_types.insert(d->device_type());
@@ -766,41 +1065,24 @@
// If the NodeDef contains a device, then we interpret it as a
// (partial) device specification.
if (!node.requested_device().empty()) {
- // The user has specified a device in the NodeDef, try to find a
- // valid device matching their specification in the set of
- // devices.
- // NOTE: The full name may specify a device that is not in
- // n.supported_device_types(), but we check that in AssignDevice().
- if (!DeviceNameUtils::ParseFullName(node.requested_device(),
- &member->device_name)) {
- return errors::InvalidArgument("Malformed device specification '",
- node.requested_device(), "'");
+ if (IsResourceGeneratorNode(node)) {
+ // Treat requested device on resource generating nodes as assigned
+ // device so that we don't override it.
+ TF_RETURN_IF_ERROR(InitializeMemberWithAssignedDevice(
+ node.requested_device(), node.type_string(), false, member));
+ } else {
+ // The user has specified a device in the NodeDef, try to find a
+ // valid device matching their specification in the set of
+ // devices.
+ // NOTE: The full name may specify a device that is not in
+ // n.supported_device_types(), but we check that in AssignDevice().
+ TF_RETURN_IF_ERROR(member->SetRequestedDeviceName(node));
}
}
}
return Status::OK();
}
- static bool HasPriorities(const PrioritizedDeviceTypeVector& device_types) {
- for (const auto& prioritized_device_type : device_types) {
- if (prioritized_device_type.second != 0) return true;
- }
- return false;
- }
-
- static bool ArePrioritiesSame(const PrioritizedDeviceTypeVector& a_types,
- const PrioritizedDeviceTypeVector& b_types) {
- if (a_types.size() != b_types.size()) {
- return false;
- }
- for (int i = 0; i < a_types.size(); ++i) {
- if (a_types[i].first != b_types[i].first) {
- return false;
- }
- }
- return true;
- }
-
// Updates target to contain the intersection of the device types in
// "target" and "other".
static void MergeSupportedDevices(PrioritizedDeviceTypeVector* target,
@@ -883,53 +1165,7 @@
// Returns the root node of the disjoint tree to which the node with the
// given id is connected.
- int FindRoot(int node_id) {
- Member& member = members_[node_id];
- DCHECK_GE(member.parent, 0);
- if (member.parent == node_id) {
- // member.parent is the root of this disjoint tree. Do nothing.
- } else {
- member.parent = FindRoot(member.parent);
- }
- // Now it is guaranteed that member.parent is the root of this disjoint
- // tree.
- DCHECK_GE(member.parent, 0);
- return member.parent;
- }
-
- // Ensures that the devices of 'dst's resource and reference match the device
- // specified for 'src', which is an input of 'dst' with a partially or fully
- // specified device.
- Status VerifyResourceAndRefInputsCanBeColocated(
- const Node* dst, const Node* src,
- const DeviceNameUtils::ParsedName& src_parsed_name) {
- std::vector<const Edge*> edges;
- TF_RETURN_IF_ERROR(dst->input_edges(&edges));
- for (const Edge* edge : edges) {
- DataType input_type = dst->input_type(edge->dst_input());
- if (input_type == DT_RESOURCE || IsRefType(input_type)) {
- const Node* input_node = edge->src();
- if (input_node == src) {
- continue;
- }
- const auto& input_root = members_[FindRoot(input_node->id())];
- const auto& input_parsed_name = input_root.device_name;
- if (DeviceNameUtils::HasSomeDetails(input_parsed_name) &&
- !DeviceNameUtils::AreCompatibleDevNames(input_parsed_name,
- src_parsed_name)) {
- return AttachDef(
- errors::InvalidArgument(
- "Could not colocate node with its "
- "resource and reference inputs; devices ",
- DeviceNameUtils::ParsedNameToString(input_parsed_name),
- " and ", DeviceNameUtils::ParsedNameToString(src_parsed_name),
- " are not compatible."),
- *dst);
- }
- }
- }
- return Status::OK();
- }
+ int FindRoot(int node_id) { return Member::FindRoot(&members_, node_id); }
const Graph* const graph_; // Not owned.
std::vector<Member> members_;
@@ -984,6 +1220,15 @@
return errors::FailedPrecondition("No devices are registered");
}
+ if (VLOG_IS_ON(3)) {
+ DumpGraphToFile("placer_input", *graph_, nullptr, "/tmp");
+ for (const Node* node : graph_->op_nodes()) {
+ VLOG(3) << " " << node->name() << ": requested: '"
+ << node->requested_device() << "' assigned: '"
+ << node->assigned_device_name() << "'";
+ }
+ }
+
ColocationGraph colocation_graph(
graph_, devices_, default_device_,
options_ == nullptr || options_->config.allow_soft_placement(),
@@ -991,14 +1236,15 @@
TF_RETURN_IF_ERROR(colocation_graph.Initialize());
- // For each node, assign a device based on the constraints in the
- // disjoint node set.
+ // For each node, assign a device based on the constraints in the disjoint
+ // node set.
std::vector<Node*> second_pass;
for (Node* node : graph_->op_nodes()) {
// The graph may have come pre-populated by the framework with assigned
// devices (e.g., for stateful placements), so the placer should not try to
// place nodes that are already placed.
if (node->has_assigned_device_name()) {
+ TF_RETURN_IF_ERROR(colocation_graph.LimitToAssignedDevice(*node));
LogDeviceAssignment(node, log_device_placement_);
continue;
}
@@ -1014,7 +1260,7 @@
continue;
}
- std::vector<Device*>* devices;
+ const std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
return AttachDef(
@@ -1062,7 +1308,7 @@
// Perform a second pass assignment for those nodes explicitly
// skipped during the first pass.
for (Node* node : second_pass) {
- std::vector<Device*>* devices;
+ const std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
return AttachDef(
@@ -1099,6 +1345,9 @@
log_device_placement_));
}
+ if (VLOG_IS_ON(3)) {
+ DumpGraphToFile("placer_output", *graph_, nullptr, "/tmp");
+ }
return Status::OK();
}
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 04e77e5..3cc5ad8 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -17,6 +17,7 @@
#include <memory>
#include <string>
+#include <unordered_set>
#include <utility>
#include <vector>
@@ -24,11 +25,15 @@
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/graph_def_builder_util.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
@@ -40,6 +45,16 @@
namespace tensorflow {
+using ::tensorflow::test::function::GDef;
+using ::tensorflow::test::function::NDef;
+using FDH = ::tensorflow::FunctionDefHelper;
+
+constexpr char kCPU[] = "/device:fakecpu:0";
+constexpr char kGPU[] = "/device:fakegpu:0";
+
+constexpr char kFullCPU[] = "/job:a/replica:0/task:0/device:fakecpu:0";
+constexpr char kFullGPU[] = "/job:a/replica:0/task:0/device:fakegpu:0";
+
namespace {
////////////////////////////////////////////////////////////////////////////////
@@ -210,6 +225,16 @@
return Status::OK();
}
+ Status BuildGraph(const GraphDef& graph_def, Graph* out_graph) {
+ GraphConstructorOptions opts;
+ TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, graph_def, out_graph));
+ nodes_by_name_.clear();
+ for (Node* node : out_graph->nodes()) {
+ nodes_by_name_[node->name()] = node->id();
+ }
+ return Status::OK();
+ }
+
// Invokes the Placer on "graph". If no DeviceSet is specified, the
// placement will use the default DeviceSet (of 10 CPU and 10 GPU devices).
//
@@ -866,7 +891,7 @@
}
TEST_F(PlacerTest, TestResourceHandlesOnDifferentDevicesFails) {
- auto handle_test = [this](bool allow_soft_placement) {
+ auto handle_test = [this](bool allow_soft_placement, bool set_assigned) {
Graph g(OpRegistry::Global());
{ // Scope for temporary variables used to construct g.
GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
@@ -878,27 +903,41 @@
b.opts().WithName("two_handles_in"));
TF_EXPECT_OK(BuildGraph(b, &g));
- GetNodeByName(g, "var_cpu")
- ->set_assigned_device_name(
- "/job:a/replica:0/task:0/device:fakecpu:0");
- GetNodeByName(g, "var_gpu")
- ->set_assigned_device_name(
- "/job:a/replica:0/task:0/device:fakegpu:0");
+ if (set_assigned) {
+ GetNodeByName(g, "var_cpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakecpu:0");
+ GetNodeByName(g, "var_gpu")
+ ->set_assigned_device_name(
+ "/job:a/replica:0/task:0/device:fakegpu:0");
+ } else {
+ GetNodeByName(g, "var_cpu")
+ ->set_requested_device("/job:a/replica:0/task:0/device:fakecpu:0");
+ GetNodeByName(g, "var_gpu")
+ ->set_requested_device("/job:a/replica:0/task:0/device:fakegpu:0");
+ }
}
SessionOptions options;
options.config.set_allow_soft_placement(allow_soft_placement);
options.config.set_log_device_placement(true);
Status s = Place(&g, &options);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code()) << s.ToString();
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
- "Could not colocate node with its resource and reference inputs"));
+ "Cannot place the graph because a reference or resource edge "
+ "connects "
+ "colocation groups with incompatible assigned devices: "
+ "/job:a/replica:0/task:0/device:fakegpu:0 vs "
+ "/job:a/replica:0/task:0/device:fakecpu:0"));
+
return Status::OK();
};
- TF_EXPECT_OK(handle_test(false));
- TF_EXPECT_OK(handle_test(true));
+ TF_EXPECT_OK(handle_test(false, false));
+ TF_EXPECT_OK(handle_test(false, true));
+ TF_EXPECT_OK(handle_test(true, false));
+ TF_EXPECT_OK(handle_test(true, true));
}
// Test that an assignment of an operator to the wrong device
@@ -1617,5 +1656,127 @@
EXPECT_DEVICE_TYPE(g, "in", "FakeGPU");
}
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("_Arg").Device("FakeGPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("_Retval").Device("FakeGPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Identity").Device("FakeGPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Const").Device("FakeGPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Mul").Device("FakeGPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeCPU"), DummyOp);
+REGISTER_KERNEL_BUILDER(Name("Add").Device("FakeGPU"), DummyOp);
+
+TEST_F(PlacerTest, RequestedDeviceOnResourceGeneratorIsTreatedAsAssigned) {
+ /*
+ * a:RES:GPU b:RES:CPU
+ * | |
+ * | |
+ * v v
+ * id1 id2
+ * @loc:id2
+ */
+ FunctionDef func = test::function::ResourceOutput();
+ GraphDef graph = GDef(
+ {
+ NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}, kGPU),
+ NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}, kCPU),
+ NDef("id1", "Identity", {"a"},
+ {{"T", DT_RESOURCE},
+ {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ NDef("id2", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
+ },
+ // FunctionLib
+ {func});
+
+ Graph g(OpRegistry::Global());
+ TF_ASSERT_OK(BuildGraph(graph, &g));
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot place the graph because a reference or resource edge connects "
+ "colocation groups with incompatible assigned devices:"));
+}
+
+TEST_F(PlacerTest, RequestedDeviceCanBeOverridden) {
+ /*
+ * a:RES b:RES
+ * | |
+ * id_a:GPU id_b:CPU
+ * | |
+ * v v
+ * id1 id2
+ * @loc:id2
+ */
+ FunctionDef func = test::function::ResourceOutput();
+ GraphDef graph = GDef(
+ {
+ NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
+ NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
+ NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}, kGPU),
+ NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}, kCPU),
+ NDef("id1", "Identity", {"id_a"},
+ {{"T", DT_RESOURCE},
+ {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
+ },
+ // FunctionLib
+ {func});
+
+ Graph g(OpRegistry::Global());
+ TF_ASSERT_OK(BuildGraph(graph, &g));
+ TF_ASSERT_OK(Place(&g));
+
+ // All should be colocated
+ EXPECT_COLOCATED(g, "a", "b");
+ EXPECT_COLOCATED(g, "id_a", "id_b");
+ EXPECT_COLOCATED(g, "id1", "id2");
+ EXPECT_COLOCATED(g, "a", "id_a");
+ EXPECT_COLOCATED(g, "a", "id1");
+}
+
+TEST_F(PlacerTest, AssignedDevicesAreNotOverriddenDueToResourcesAndColocation) {
+ /*
+ * a:RES b:RES
+ * | |
+ * id_a:GPU id_b:CPU
+ * | |
+ * v v
+ * id1 id2
+ * @loc:id2
+ */
+ FunctionDef func = test::function::ResourceOutput();
+ GraphDef graph = GDef(
+ {
+ NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}),
+ NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}),
+ NDef("id_a", "Identity", {"a"}, {{"T", DT_RESOURCE}}),
+ NDef("id_b", "Identity", {"b"}, {{"T", DT_RESOURCE}}),
+ NDef("id1", "Identity", {"id_a"},
+ {{"T", DT_RESOURCE},
+ {"_class", gtl::ArraySlice<string>({"loc:@id2"})}}),
+ NDef("id2", "Identity", {"id_b"}, {{"T", DT_RESOURCE}}),
+ },
+ // FunctionLib
+ {func});
+
+ Graph g(OpRegistry::Global());
+ TF_ASSERT_OK(BuildGraph(graph, &g));
+ std::unordered_map<string, Node*> nodes = g.BuildNodeNameIndex();
+ GetNodeByName(g, "id_a")->set_assigned_device_name(kFullGPU);
+ GetNodeByName(g, "id_b")->set_assigned_device_name(kFullCPU);
+ Status s = Place(&g);
+ EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot place the graph because a reference or resource edge connects "
+ "colocation groups with incompatible assigned devices: "
+ "/job:a/replica:0/task:0/device:fakecpu:0 vs "
+ "/job:a/replica:0/task:0/device:fakegpu:0"));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/shared_counter.h b/tensorflow/core/common_runtime/shared_counter.h
new file mode 100644
index 0000000..5e37852
--- /dev/null
+++ b/tensorflow/core/common_runtime/shared_counter.h
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_
+
+namespace tensorflow {
+// A lightweight thread-safe monotone counter for establishing
+// temporal ordering.
+class SharedCounter {
+ public:
+ int64 get() { return value_; }
+ int64 next() { return ++value_; }
+
+ private:
+ std::atomic<int64> value_{0};
+};
+
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SHARED_COUNTER_H_
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index 4926544..1bdccf5 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -409,6 +409,21 @@
}
}
+void StepStatsCollector::SaveThreadName(const string& device,
+ const uint32 thread_id,
+ const string& thread_name) {
+ VLOG(1) << "Save dev " << device << " thread id " << thread_id << " name "
+ << thread_name;
+ {
+ mutex_lock l(mu_);
+ if (finalized_) {
+ LOG(WARNING) << "thread_name saved after finalize will not be collected.";
+ }
+ auto& thread_names_map = thread_names_[device];
+ thread_names_map[thread_id] = thread_name;
+ }
+}
+
NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
const Node* node) {
// Only collect statistics for non-transfer nodes.
@@ -531,5 +546,15 @@
stats->stats()->Swap(dss->add_node_stats());
}
}
+ for (const auto& device_thread : thread_names_) {
+ if (dev_stats_pb.find(device_thread.first) == dev_stats_pb.end()) {
+ // skip device without DeviceStepStats.
+ continue;
+ }
+ DeviceStepStats* dss = dev_stats_pb.at(device_thread.first);
+ for (const auto& thread_name : device_thread.second) {
+ (*dss->mutable_thread_names())[thread_name.first] = thread_name.second;
+ }
+ }
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 7d34383..dfcc51f 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -175,6 +175,10 @@
void Save(const string& device, NodeExecStats* node_stats_pb);
void Save(const string& device, NodeExecStatsWrapper* node_stats);
+ // Saves thread name.
+ void SaveThreadName(const string& device, const uint32 thread_id,
+ const string& thread_name);
+
NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
string ReportAllocsOnResourceExhausted(const string& err) override;
@@ -191,12 +195,14 @@
static const uint64 kMaxCollectedNodes = 1 << 20;
typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+ typedef std::unordered_map<uint32, string> ThreadNamesMap;
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, ThreadNamesMap> thread_names_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
diff --git a/tensorflow/core/debug/BUILD b/tensorflow/core/debug/BUILD
index e1961b8..f8c07dd 100644
--- a/tensorflow/core/debug/BUILD
+++ b/tensorflow/core/debug/BUILD
@@ -229,7 +229,9 @@
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 5fc95a8..b69eb1d 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -299,7 +299,7 @@
auto builder = NodeDefBuilder(copy_node_name, copy_op_name)
.Input(src_node_name, src_output, src_dt)
- .Attr("debug_ops_spec", std::move(debug_ops_spec));
+ .Attr("debug_ops_spec", debug_ops_spec);
if (!builder.Finalize(&node_def).ok()) {
return Status(
diff --git a/tensorflow/core/debug/debug_grpc_testlib.cc b/tensorflow/core/debug/debug_grpc_testlib.cc
index f70931e..4927caf 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.cc
+++ b/tensorflow/core/debug/debug_grpc_testlib.cc
@@ -18,6 +18,7 @@
#include "tensorflow/core/debug/debug_graph_utils.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 6994dec..ebcb046 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -35,6 +35,7 @@
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -730,7 +731,7 @@
::grpc::ChannelArguments args;
args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
// Avoid problems where default reconnect backoff is too long (e.g., 20 s).
- args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
+ args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
channel_ = ::grpc::CreateCustomChannel(
server_stream_addr_, ::grpc::InsecureChannelCredentials(), args);
if (!channel_->WaitForConnected(
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index 82e0ae5..0926a82 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -22,6 +22,7 @@
#include "tensorflow/core/debug/debug_node_key.h"
#include "tensorflow/core/debug/debugger_event_metadata.pb.h"
#include "tensorflow/core/framework/summary.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/status_test_util.h"
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index 351dbeb..da88f9c 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -591,7 +591,9 @@
"//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
@@ -658,6 +660,7 @@
"//tensorflow/core:lib_internal",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
diff --git a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
index 9f94a24..443759a 100644
--- a/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
+++ b/tensorflow/core/distributed_runtime/collective_param_resolver_distributed.cc
@@ -82,15 +82,21 @@
const ConfigProto& config, const DeviceMgr* dev_mgr,
DeviceResolverDistributed* dev_resolver, WorkerCacheInterface* worker_cache,
const string& task_name)
- : CollectiveParamResolverLocal(dev_mgr, dev_resolver, task_name),
+ : CollectiveParamResolverLocal(config, dev_mgr, dev_resolver, task_name),
worker_cache_(worker_cache),
group_leader_(task_name == config.experimental().collective_group_leader()
? ""
- : config.experimental().collective_group_leader()) {}
+ : config.experimental().collective_group_leader()) {
+ VLOG(1) << "CompleteParamResolverDistributed ctor task={" << task_name
+ << "} config.collective_group_leader={"
+ << config.experimental().collective_group_leader() << "}";
+}
void CollectiveParamResolverDistributed::CompleteParamsAsync(
const string& device, CollectiveParams* cp, CancellationManager* cancel_mgr,
const StatusCallback& done) {
+ VLOG(1) << "CompleteParams distributed " << device << " for " << cp << ": "
+ << cp->ToString();
CompleteGroupDistributed(device, cp, cancel_mgr,
[this, device, cp, cancel_mgr, done](
const Status& s, const GroupRec* gr) {
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index 48b72fb..2f14967 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -1103,6 +1103,8 @@
req.options().experimental().collective_graph_key();
if (config.experimental().collective_deterministic_sequential_execution()) {
opts->collective_order = GraphCollectiveOrder::kEdges;
+ } else if (config.experimental().collective_nccl()) {
+ opts->collective_order = GraphCollectiveOrder::kAttrs;
}
}
diff --git a/tensorflow/core/distributed_runtime/rpc/BUILD b/tensorflow/core/distributed_runtime/rpc/BUILD
index a081ec7..6079634 100644
--- a/tensorflow/core/distributed_runtime/rpc/BUILD
+++ b/tensorflow/core/distributed_runtime/rpc/BUILD
@@ -159,6 +159,17 @@
],
)
+cc_library(
+ name = "grpc_response_cache",
+ srcs = ["grpc_response_cache.cc"],
+ hdrs = ["grpc_response_cache.h"],
+ deps = [
+ ":grpc_util",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/time",
+ ],
+)
+
tf_cuda_library(
name = "grpc_worker_service",
srcs = ["grpc_worker_service.cc"],
@@ -166,6 +177,7 @@
deps = [
":async_service_interface",
":grpc_call",
+ ":grpc_response_cache",
":grpc_tensor_coding",
":grpc_util",
":grpc_worker_service_impl",
@@ -183,6 +195,8 @@
"//tensorflow/core/distributed_runtime:worker_cache",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/distributed_runtime:worker_session",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/time",
],
)
@@ -313,11 +327,15 @@
":grpc_server_lib",
"//tensorflow:grpc++",
"//tensorflow/core:core_cpu",
+ "//tensorflow/core:data_flow_ops_op_lib",
"//tensorflow/core:framework_internal",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
+ "//tensorflow/core:lookup_ops_op_lib",
"//tensorflow/core:math_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:sendrecv_ops_op_lib",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:data_flow",
],
@@ -339,6 +357,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:testlib",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/kernels:constant_op",
@@ -478,6 +497,7 @@
"//tensorflow/core:lib",
"//tensorflow/core:master_proto_cc",
"//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index 781b7d6..64c2218 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -62,7 +62,7 @@
args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, std::numeric_limits<int32>::max());
// NOTE(mrry): Some versions of gRPC use a 20-second minimum backoff
// on connection failure, which makes our tests time out.
- args.SetInt("grpc.testing.fixed_reconnect_backoff_ms", 1000);
+ args.SetInt(GRPC_ARG_MAX_RECONNECT_BACKOFF_MS, 1000);
if (rpc_options != nullptr) {
if (rpc_options->compression_algorithm() == "deflate") {
args.SetCompressionAlgorithm(GRPC_COMPRESS_DEFLATE);
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
index 2daefcb..2479e73 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
@@ -39,6 +39,8 @@
namespace tensorflow {
+const int kMaxWorkerRpcRetries = 10;
+
class GrpcRemoteWorker : public WorkerInterface {
public:
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
@@ -259,17 +261,19 @@
// given callback, `done`, will be called when the RPC completes.
void IssueRequest(const protobuf::Message* request,
protobuf::Message* response, const ::grpc::string& method,
- StatusCallback done, CallOptions* call_opts = nullptr) {
+ StatusCallback done, CallOptions* call_opts = nullptr,
+ int max_retries = kMaxWorkerRpcRetries) {
new RPCState<protobuf::Message>(&stub_, cq_, method, *request, response,
std::move(done), call_opts,
- callback_threadpool_);
+ callback_threadpool_, max_retries);
}
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
const ::grpc::string& method, StatusCallback done,
- CallOptions* call_opts = nullptr) {
+ CallOptions* call_opts = nullptr,
+ int max_retries = kMaxWorkerRpcRetries) {
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
std::move(done), call_opts,
- callback_threadpool_);
+ callback_threadpool_, max_retries);
}
// Helper function for initializing the RpcMethod objects below.
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.cc
new file mode 100644
index 0000000..613c290
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.cc
@@ -0,0 +1,183 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+struct WorkerCacheEntry {
+ enum class State {
+ PENDING = 0,
+ ACTIVE = 1,
+ FINISHED = 2,
+ };
+
+ State state = State::PENDING;
+ int64 expires_seconds;
+
+ ::grpc::ByteBuffer response_buf;
+ Status response_status;
+
+ // Additional retries may arrive while a request is still executing. The
+ // callbacks for these calls are queued in `callbacks` and evaluated after
+ // the original request is completed.
+ std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
+};
+
+void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
+ if (buf_ != nullptr) {
+ *tgt = *buf_;
+ } else {
+ CHECK(msg_ != nullptr);
+ ::grpc::Slice slice(msg_->ByteSizeLong());
+ msg_->SerializeWithCachedSizesToArray(
+ const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
+ ::grpc::ByteBuffer tmp(&slice, 1);
+ tgt->Swap(&tmp);
+ }
+}
+
+void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
+ if (buf_ != nullptr) {
+ *buf_ = src;
+ return;
+ }
+
+ CHECK(msg_ != nullptr);
+ // We create a single slice when encoding protocol messages.
+ std::vector<::grpc::Slice> slices;
+ if (src.Dump(&slices).ok()) {
+ msg_->ParseFromArray(slices[0].begin(), slices[0].size());
+ } else {
+ LOG(ERROR) << "Failed to decode cached buffer.";
+ }
+}
+
+void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
+ ComputeFunc compute_func,
+ StatusCallback done_cb) {
+ VLOG(1) << "Lookup " << key;
+ std::shared_ptr<WorkerCacheEntry> req;
+ MaybeCleanup();
+ {
+ mutex_lock m(mu_);
+
+ if (requests_.find(key) != requests_.end()) {
+ req = requests_[key];
+ } else {
+ req.reset(new WorkerCacheEntry);
+ requests_[key] = req;
+ }
+
+ if (req->state == WorkerCacheEntry::State::FINISHED) {
+ if (req->expires_seconds > Env::Default()->NowSeconds()) {
+ VLOG(1) << "Reuse cached response for " << key;
+ response.CopyFrom(req->response_buf);
+ done_cb(req->response_status);
+ return;
+ }
+ VLOG(1) << "Found expired cache entry for " << key;
+ req->state = WorkerCacheEntry::State::PENDING;
+ req->response_buf.Clear();
+ }
+
+ req->callbacks.push_back(std::make_pair(response, done_cb));
+
+ if (req->state == WorkerCacheEntry::State::ACTIVE) {
+ VLOG(1) << "Found active request for " << key
+ << ". Adding entry to response queue.";
+ return;
+ }
+
+ VLOG(2) << "No cache entry for " << key << ", running user computation.";
+ req->state = WorkerCacheEntry::State::ACTIVE;
+ req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
+ }
+
+ compute_func([this, key, req, response](Status status) {
+ mutex_lock m(mu_);
+ response.Encode(&req->response_buf);
+ current_bytes_ += req->response_buf.Length();
+
+ req->response_status = status;
+ req->state = WorkerCacheEntry::State::FINISHED;
+
+ VLOG(1) << "Operation for " << key << " finished. "
+ << "Status: " << status << ", " << req->response_buf.Length()
+ << " response bytes, " << req->callbacks.size()
+ << " pending callbacks.";
+ for (auto& cb : req->callbacks) {
+ cb.first.CopyFrom(req->response_buf);
+ cb.second(req->response_status);
+ }
+ req->callbacks.clear();
+ });
+}
+
+// Remove all stale or expired cache entries if the cache is full.
+void GrpcResponseCache::MaybeCleanup() {
+ mutex_lock m(mu_);
+ if (current_bytes_ < max_bytes_) {
+ return;
+ }
+
+ VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
+ std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
+ ordered_entries;
+ ordered_entries.reserve(requests_.size());
+ for (const auto& p : requests_) {
+ ordered_entries.push_back(std::make_pair(p.first, p.second));
+ }
+
+ std::sort(ordered_entries.begin(), ordered_entries.end(),
+ [](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
+ const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
+ return a.second->expires_seconds > b.second->expires_seconds;
+ });
+
+ std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
+ int64 now = Env::Default()->NowSeconds();
+ int64 bytes_used = 0;
+
+ // Always keep active requests.
+ for (auto& pair : ordered_entries) {
+ if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
+ kept.insert(pair);
+ }
+ }
+
+ // Keep unexpired, finished requests up to half of max_bytes_. This reduces
+ // chances of overfilling the cache when active requests complete and
+ // amortizes cache cleanup cost.
+ for (auto& pair : ordered_entries) {
+ if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
+ break;
+ }
+
+ if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
+ kept.insert(pair);
+ bytes_used += pair.second->response_buf.Length();
+ }
+ }
+
+ VLOG(1) << "Cleaned cache. Bytes used: " << current_bytes_ << " -> "
+ << bytes_used << ". Cache size: " << requests_.size() << " -> "
+ << kept.size();
+ current_bytes_ = bytes_used;
+ std::swap(requests_, kept);
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h
new file mode 100644
index 0000000..0892d9f
--- /dev/null
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h
@@ -0,0 +1,91 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
+
+#include <memory>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+// gRPC response caching. Most WorkerService methods cannot be retried directly
+// as they will fail or deadlock. To enable retrying, we can instead cache
+// responses for a short period of time and reply to duplicate requests from the
+// cache.
+namespace tensorflow {
+
+// Union type to aid caching of either raw buffers (for RecvTensor RPCs) and
+// protocol buffer messages (for all other RPCs).
+class RPCResponse {
+ public:
+ explicit RPCResponse() : buf_(nullptr), msg_(nullptr) {}
+ explicit RPCResponse(::grpc::ByteBuffer* b) : buf_(b), msg_(nullptr) {}
+ explicit RPCResponse(protobuf::Message* m) : buf_(nullptr), msg_(m) {}
+
+ // Encode this response into the target buffer.
+ void Encode(::grpc::ByteBuffer* tgt) const;
+
+ // Copy from `src`: if this is a buffer, make a shallow copy.
+ // For protocol messages, parse the response from `src`.
+ void CopyFrom(const ::grpc::ByteBuffer& src);
+
+ private:
+ ::grpc::ByteBuffer* buf_;
+ protobuf::Message* msg_;
+};
+
+typedef std::function<void(StatusCallback)> ComputeFunc;
+struct WorkerCacheEntry;
+
+// Track and cache the state of worker service RPCs. An RPC can be in 3 states:
+//
+// * PENDING: this is the first call of the RPC, and it will transition to
+// * ACTIVE: another thread is active processing this RPC
+// * FINISHED: the worker has finished processing the method
+//
+// The response from completed RPCs are LRU cached until either `max_bytes`
+// bytes are in use by the cache or they expire (according to `expire_time`).
+class GrpcResponseCache {
+ public:
+ GrpcResponseCache(int64 max_bytes, int64 expire_time_seconds)
+ : max_bytes_(max_bytes), expire_time_seconds_(expire_time_seconds) {}
+
+ // Lookup the result for key.
+ // If it is finished, invoke `done_cb` immediately after filling `response`.
+ // If active, done_db will be invoked when the current call completes.
+ // Otherwise, invoke `compute_func` to fill the cache and invoke done_cb.
+ void LookupOrCompute(const string& key, RPCResponse response,
+ ComputeFunc compute_func, StatusCallback done_cb);
+
+ // Remove all stale or expired cache entries if the cache is full.
+ void MaybeCleanup();
+
+ private:
+ int64 current_bytes_ GUARDED_BY(mu_) = 0;
+ const int64 max_bytes_;
+ const int64 expire_time_seconds_;
+
+ std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> requests_
+ GUARDED_BY(mu_);
+ mutex mu_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index 7f63cc9..3635caf 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -210,7 +210,7 @@
get_stub(index), &completion_queue_, *get_method_ptr(index),
call->request(), call->response(),
/*done=*/[call](const Status& s) { call->Done(s); }, call->call_opts(),
- nullptr /*threadpool*/, fail_fast_, timeout_in_ms_);
+ nullptr /*threadpool*/, fail_fast_, timeout_in_ms_, 0 /* max_retries */);
}
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index 1405c76..f087a39 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -431,7 +431,7 @@
ServiceInitFunction service_func = nullptr;
GrpcServerOptions options;
options.rendezvous_mgr_func = NewRpcRendezvousMgr;
- Status s = ret->Init();
+ Status s = ret->Init(options);
if (!s.ok()) {
LOG(ERROR) << s;
return s;
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_state.h b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
index d736386..0ca64dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_state.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_state.h
@@ -32,6 +32,9 @@
namespace tensorflow {
// Object allocated per active RPC.
+// Manage the state of a single asynchronous RPC request. If `max_retries`
+// is greater than 0, the request will be retried for any transient failures
+// as long as the overall deadline has not elapsed.
template <class Response>
class RPCState : public GrpcClientCQTag {
public:
@@ -39,34 +42,55 @@
RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
const ::grpc::string& method, const protobuf::Message& request,
Response* response, StatusCallback done, CallOptions* call_opts,
- thread::ThreadPool* threadpool)
+ thread::ThreadPool* threadpool, int32 max_retries = 0)
: RPCState(stub, cq, method, request, response, std::move(done),
call_opts, threadpool, /*fail_fast=*/false,
- /*timeout_in_ms=*/0) {}
+ /*timeout_in_ms=*/0, max_retries) {}
template <typename Request>
RPCState(::grpc::GenericStub* stub, ::grpc::CompletionQueue* cq,
const ::grpc::string& method, const Request& request,
Response* response, StatusCallback done, CallOptions* call_opts,
- thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms)
- : call_opts_(call_opts), threadpool_(threadpool), done_(std::move(done)) {
- context_.set_fail_fast(fail_fast);
- if (timeout_in_ms > 0) {
- context_.set_deadline(gpr_time_from_millis(timeout_in_ms, GPR_TIMESPAN));
- }
-
- if (call_opts) {
- call_opts->SetCancelCallback([this]() { context_.TryCancel(); });
- }
-
+ thread::ThreadPool* threadpool, bool fail_fast, int64 timeout_in_ms,
+ int32 max_retries)
+ : call_opts_(call_opts),
+ threadpool_(threadpool),
+ done_(std::move(done)),
+ cq_(cq),
+ stub_(stub),
+ method_(method),
+ max_retries_(max_retries),
+ timeout_in_ms_(timeout_in_ms),
+ fail_fast_(fail_fast) {
response_ = response;
::grpc::Status s = GrpcMaybeUnparseProto(request, &request_buf_);
if (!s.ok()) {
LOG(ERROR) << "GrpcMaybeUnparseProto returned with non-ok status: "
<< s.error_message();
+ // Skip retry logic if we fail to parse our request.
+ done_(FromGrpcStatus(s));
+ delete this;
+ return;
}
- call_ =
- std::move(stub->PrepareUnaryCall(&context_, method, request_buf_, cq));
+ StartCall();
+ }
+
+ void StartCall() {
+ context_.reset(new ::grpc::ClientContext());
+ context_->set_fail_fast(fail_fast_);
+
+ if (timeout_in_ms_ > 0) {
+ context_->set_deadline(
+ gpr_time_from_millis(timeout_in_ms_, GPR_TIMESPAN));
+ }
+ if (call_opts_) {
+ call_opts_->SetCancelCallback([this]() { context_->TryCancel(); });
+ }
+
+ VLOG(2) << "Starting call: " << method_;
+
+ call_ = std::move(
+ stub_->PrepareUnaryCall(context_.get(), method_, request_buf_, cq_));
call_->StartCall();
call_->Finish(&response_buf_, &status_, this);
}
@@ -89,16 +113,26 @@
threadpool_->Schedule([this]() { ParseAndCallDone(); });
} else {
ParseAndCallDone();
- return;
}
- } else {
- VLOG(2) << "Call returned with non-ok status: " << s;
+ return;
+ }
- // Attach additional GRPC error information if any
+ VLOG(1) << method_ << " returned with non-ok status: " << s
+ << " Retries: " << num_retries_ << " Max: " << max_retries_ << "\n"
+ << context_->debug_error_string();
+ // Retry if we have any attempts left
+ if (++num_retries_ <= max_retries_ &&
+ (errors::IsUnavailable(s) || errors::IsUnknown(s))) {
+ response_buf_.Clear();
+ VLOG(1) << "Retrying call for " << method_ << "Retry: " << num_retries_
+ << " of " << max_retries_;
+ StartCall();
+ } else {
+ // Attach additional GRPC error information if any to the final status
s = Status(s.code(),
strings::StrCat(s.error_message(),
"\nAdditional GRPC error information:\n",
- context_.debug_error_string()));
+ context_->debug_error_string()));
done_(s);
delete this;
}
@@ -115,7 +149,7 @@
private:
CallOptions* call_opts_;
- ::grpc::ClientContext context_;
+ std::unique_ptr<::grpc::ClientContext> context_;
thread::ThreadPool* threadpool_;
std::unique_ptr<::grpc::GenericClientAsyncResponseReader> call_;
Response* response_;
@@ -123,6 +157,15 @@
::grpc::ByteBuffer response_buf_;
::grpc::Status status_;
StatusCallback done_;
+ int64 timeout_in_ms_;
+
+ size_t num_retries_ = 0;
+ size_t max_retries_;
+
+ ::grpc::CompletionQueue* cq_;
+ ::grpc::GenericStub* stub_;
+ ::grpc::string method_;
+ bool fail_fast_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
index 34bf629..9048621 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.cc
@@ -16,11 +16,14 @@
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h"
#include <deque>
+#include <memory>
#include <unordered_map>
+#include <vector>
#include "grpcpp/alarm.h"
#include "grpcpp/server_builder.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
@@ -32,6 +35,7 @@
#include "tensorflow/core/distributed_runtime/rendezvous_mgr_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/async_service_interface.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_call.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_tensor_coding.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
@@ -42,8 +46,12 @@
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/protobuf/transport_options.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
@@ -91,10 +99,11 @@
public:
explicit GrpcWorkerServiceThread(
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
- std::unordered_map<int, int> queue_depth,
+ std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
grpc::WorkerService::AsyncService* worker_service)
: worker_(worker),
queue_depth_(queue_depth),
+ cache_(cache),
worker_service_(worker_service),
is_shutdown_(false) {
cq_ = builder->AddCompletionQueue();
@@ -220,18 +229,32 @@
NonOwnedProtoRunGraphResponse* wrapped_response =
new NonOwnedProtoRunGraphResponse(&call->response);
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
- worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
- [call, call_opts, wrapped_request,
- wrapped_response](const Status& s) {
- if (!s.ok()) {
- VLOG(1) << "Bad response from RunGraph:" << s;
- }
- call->ClearCancelCallback();
- delete call_opts;
- delete wrapped_request;
- delete wrapped_response;
- call->SendResponse(ToGrpcStatus(s));
- });
+ auto done_cb = [call, call_opts, wrapped_request,
+ wrapped_response](const Status& s) {
+ VLOG(1) << "RunGraph::Done";
+ if (!s.ok()) {
+ VLOG(1) << "Bad response from RunGraph:" << s;
+ }
+ call->ClearCancelCallback();
+ delete call_opts;
+ delete wrapped_request;
+ delete wrapped_response;
+ call->SendResponse(ToGrpcStatus(s));
+ };
+
+ auto compute_fn = [this, call_opts, wrapped_request,
+ wrapped_response](StatusCallback done) {
+ worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
+ done);
+ };
+
+ if (cache_) {
+ string request_key = call->request.ShortDebugString();
+ cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
+ compute_fn, done_cb);
+ } else {
+ compute_fn(done_cb);
+ }
});
ENQUEUE_REQUEST(RunGraph, true);
}
@@ -241,16 +264,28 @@
Schedule([this, call]() {
CallOptions* call_opts = new CallOptions;
call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
- worker_->GrpcRecvTensorAsync(
- call_opts, &call->request, &call->response,
- [call, call_opts](const Status& s) {
- call->ClearCancelCallback();
- delete call_opts;
- if (!s.ok()) {
- VLOG(1) << "Bad response from RecvTensor:" << s;
- }
- call->SendResponse(ToGrpcStatus(s));
- });
+
+ auto done_cb = [call, call_opts](const Status& s) {
+ call->ClearCancelCallback();
+ delete call_opts;
+ if (!s.ok()) {
+ VLOG(1) << "Bad response from RecvTensor:" << s;
+ }
+ call->SendResponse(ToGrpcStatus(s));
+ };
+
+ auto compute_fn = [this, &call_opts, &call](StatusCallback done) {
+ worker_->GrpcRecvTensorAsync(call_opts, &call->request, &call->response,
+ done);
+ };
+
+ if (cache_) {
+ string request_key = call->request.ShortDebugString();
+ cache_->LookupOrCompute(request_key, RPCResponse(&call->response),
+ compute_fn, done_cb);
+ } else {
+ compute_fn(done_cb);
+ }
});
EnqueueRecvTensorRequestRaw();
}
@@ -328,6 +363,7 @@
std::unique_ptr<::grpc::ServerCompletionQueue> cq_;
std::unique_ptr<Thread> thread_;
std::unordered_map<int, int> queue_depth_;
+ GrpcResponseCache* cache_;
grpc::WorkerService::AsyncService* const worker_service_;
mutex shutdown_mu_;
@@ -341,9 +377,16 @@
GrpcWorkerServiceOptions options)
: is_shutdown_(false) {
builder->RegisterService(&worker_service_);
- for (int i = 0; i < options.num_worker_threads; i++) {
- threads_.emplace_back(new GrpcWorkerServiceThread(
- worker, builder, options.queue_depth, &worker_service_));
+ if (options.response_cache_bytes > 0) {
+ cache_.reset(
+ new GrpcResponseCache(options.response_cache_bytes,
+ options.response_cache_expires_seconds));
+ }
+
+ for (int i = 0; i < options.num_serving_threads; i++) {
+ threads_.emplace_back(
+ new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
+ cache_.get(), &worker_service_));
}
}
@@ -378,6 +421,7 @@
grpc::WorkerService::AsyncService worker_service_;
std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
+ std::unique_ptr<GrpcResponseCache> cache_;
mutex service_shutdown_mu_;
bool is_shutdown_ GUARDED_BY(service_shutdown_mu_);
@@ -422,11 +466,14 @@
return;
}
- // Request the tensor associated with the rendezvous key. Any time
- // while waiting for the tensor to be produced, up until the start
- // of execution of the callback lambda body below, an RPC
- // cancellation should abort the rendezvous.
- opts->SetCancelCallback([this, step_id]() { AbortStep(step_id); });
+ // Request the tensor associated with the rendezvous key.
+ // Note that we log the cancellation here but do not abort the current step.
+ // gRPC can generate cancellations in response to transient network failures,
+ // and aborting the step eliminates the opportunity for client side retries.
+ // Repeated client failures will eventually cause the step to be aborted by
+ // the client.
+ opts->SetCancelCallback(
+ [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
env_->rendezvous_mgr->RecvLocalAsync(
step_id, parsed,
[opts, response, done, src_dev, request](
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
index 88beb6c..8f2830c 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h
@@ -16,8 +16,10 @@
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
+#include <memory>
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/recent_request_ids.h"
+#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
#include "tensorflow/core/distributed_runtime/worker.h"
@@ -63,7 +65,9 @@
// Map from GrpcWorkerMethod id to queue depth. If set this overrides the
// default queue depth for a method.
std::unordered_map<int, int> queue_depth;
- int num_worker_threads = 8;
+ int num_serving_threads = 8;
+ int64 response_cache_bytes = 0;
+ int64 response_cache_expires_seconds = 0;
};
// Returns an implementation of WorkerService rpc service.
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.cc b/tensorflow/core/distributed_runtime/tensor_coding.cc
index fe2d1a1..6d20e7c 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.cc
+++ b/tensorflow/core/distributed_runtime/tensor_coding.cc
@@ -68,13 +68,14 @@
return s;
}
-void TensorResponse::InitPartial(const RecvTensorResponse& response) {
+void TensorResponse::InitPartial(const RecvTensorResponse& response,
+ const AllocationAttributes& allocation_attr) {
// Everything except content is present in *response. Content will
// arrive later; allocate a Tensor with appropriate storage for that
// content.
meta_ = response;
TensorShape shape(meta_.tensor().tensor_shape());
- Tensor t(allocator_, meta_.tensor().dtype(), shape);
+ Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
tensor_ = std::move(t);
}
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
index 4c34297..86d95a3 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.h
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -76,7 +76,8 @@
// Initialize tensor metadata from response and allocate
// uninitialized backing storage for actual contents.
- void InitPartial(const RecvTensorResponse& response);
+ void InitPartial(const RecvTensorResponse& response,
+ const AllocationAttributes& allocation_attr);
// Return a reference to the parsed tensor. The tensor will remain
// live only until *this is destroyed or modified.
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 3ded86e..4d0c6d4 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -46,6 +46,10 @@
// which Op is performing the allocation, and sets this flag to
// true.
bool allocation_will_be_logged = false;
+ // EXPERIMENTAL: If provided, then evaluates to a timing count such that only
+ // a memory chunk whose last-freed count is at this value or earlier may be
+ // returned.
+ std::function<uint64()> freed_by_func = nullptr;
};
// Runtime statistics collected by an allocator.
diff --git a/tensorflow/core/framework/attr_value_util.cc b/tensorflow/core/framework/attr_value_util.cc
index 79966f0..43b4352 100644
--- a/tensorflow/core/framework/attr_value_util.cc
+++ b/tensorflow/core/framework/attr_value_util.cc
@@ -54,9 +54,7 @@
DCHECK(success);
TensorProto p;
tensor.AsProtoTensorContent(&p);
- string s;
- SerializeToStringDeterministic(p, &s);
- return Hash64(s);
+ return DeterministicProtoHash64(p);
}
// Do not create large tensors in memory, compute hash based on TensorProto
@@ -64,12 +62,8 @@
// different hash code if they are defined with different TensorProto
// representations.
uint64 FastTensorProtoHash(const TensorProto& tp) {
- string s;
if (TensorByteSize(tp) > kMaxAttrValueTensorByteSize) {
- string s;
- bool success = SerializeToStringDeterministic(tp, &s);
- DCHECK(success);
- return Hash64(s);
+ return DeterministicProtoHash64(tp);
} else {
return TensorProtoHash(tp);
}
@@ -95,11 +89,7 @@
TensorProto rhs_tp;
rhs_t.AsProtoTensorContent(&rhs_tp);
- string lhs_str, rhs_str;
- SerializeToStringDeterministic(lhs_tp, &lhs_str);
- SerializeToStringDeterministic(rhs_tp, &rhs_str);
-
- return lhs_str == rhs_str;
+ return AreSerializedProtosEqual(lhs_tp, rhs_tp);
}
// Do not construct large tensors in memory, compare equality using TensorProto
@@ -139,9 +129,7 @@
}
// If `a` is not a tensor or func, get a hash of serialized string.
- string s;
- SerializeToStringDeterministic(a, &s);
- return Hash64(s);
+ return DeterministicProtoHash64(a);
}
bool AreAttrValuesEqual(const AttrValue& a, const AttrValue& b,
@@ -175,10 +163,7 @@
// All other fields in AttrValue have deterministic representations.
// It is safe to compare their serialized strings.
- string a_str, b_str;
- SerializeToStringDeterministic(a, &a_str);
- SerializeToStringDeterministic(b, &b_str);
- return a_str == b_str;
+ return AreSerializedProtosEqual(a, b);
}
string SummarizeString(const string& str) {
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 83bc950..5c974a7 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1303,6 +1303,12 @@
c->num_inputs() - 1 /* dim_index */);
}
+Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat) {
+ return ConcatShapeHelper(c, 0 /* start_value_index */,
+ num_inputs_to_concat /* end_value_index */,
+ num_inputs_to_concat /* dim_index */);
+}
+
Status BroadcastBinaryOpOutputShapeFnHelper(InferenceContext* c,
ShapeHandle shape_x,
ShapeHandle shape_y,
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 14b9688..d421844 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -279,6 +279,8 @@
// Shape function for concat operations.
Status ConcatV2Shape(shape_inference::InferenceContext* c);
+Status QuantizedConcatV2Shape(InferenceContext* c, int num_inputs_to_concat);
+
// Shape function for binary operators that broadcast their inputs
// and with output to output_index.
// Note: out cannot be NULL.
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index 321947a..89ba662 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -246,6 +246,15 @@
return errors::Internal("Device does not implement MakeTensorFromProto()");
}
+ // Some devices (i.e. GPUs) may free device memory prior to its actual use
+ // being completed on the assumption that subsequent allocations can only be
+ // used serially with respect to pending uses. If this function returns a
+ // non-zero value it is the value of a device-specific counter such that any
+ // device memory tagged with an earlier freed-at count is really unencumbered
+ // by pending uses. For this to be useful the device memory allocator must
+ // be tagging deallocated memory chunks using the same counter.
+ virtual uint64 SafeAllocFrontier() { return 0; }
+
protected:
// Does not take ownership.
void set_tensorflow_device_thread_pool(thread::ThreadPool* thread_pool) {
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index e369e88..fee5237 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -515,10 +515,13 @@
". (Check whether your GraphDef-interpreting binary is up to date "
"with your GraphDef-generating binary.).");
}
- TF_RETURN_WITH_CONTEXT_IF_ERROR(
- ValidateAttrValue(attr.second, *iter->second),
- "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
- SummarizeOpDef(op_def));
+ // If attr value is placeholder, do not check it.
+ if (attr.second.placeholder().empty()) {
+ TF_RETURN_WITH_CONTEXT_IF_ERROR(
+ ValidateAttrValue(attr.second, *iter->second),
+ "; NodeDef: ", FormatNodeDefForError(node_def), "; ",
+ SummarizeOpDef(op_def));
+ }
// Keep track of which attr names have (not) been found in the NodeDef.
op_attrs.erase(iter);
}
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 3597f43..d629719 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -838,20 +838,14 @@
OpDef o2_copy = o2;
o1_copy.clear_attr();
o2_copy.clear_attr();
- string s1, s2;
- SerializeToStringDeterministic(o1_copy, &s1);
- SerializeToStringDeterministic(o2_copy, &s2);
- if (s1 != s2) return false;
- return true;
+ return AreSerializedProtosEqual(o1_copy, o2_copy);
}
uint64 OpDefHash(const OpDef& o) {
uint64 h = RepeatedAttrDefHash(o.attr());
OpDef o_copy = o;
o_copy.clear_attr();
- string s;
- SerializeToStringDeterministic(o_copy, &s);
- return Hash64(s.data(), s.size(), h);
+ return DeterministicProtoHash64(o_copy, h);
}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/reader_base.cc b/tensorflow/core/framework/reader_base.cc
index f84ef0f..ed4ff24 100644
--- a/tensorflow/core/framework/reader_base.cc
+++ b/tensorflow/core/framework/reader_base.cc
@@ -241,7 +241,7 @@
num_records_produced_ = state.num_records_produced();
work_ = state.current_work();
if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
const string debug_string = "<debug state not available>";
#else
const string debug_string = state.DebugString();
@@ -251,7 +251,7 @@
debug_string);
}
if (work_started_ > work_finished_) {
-#ifdef __ANDROID__
+#if defined(__ANDROID__) || (__EMSCRIPTEN__)
const string debug_string = "<debug state not available>";
#else
const string debug_string = state.DebugString();
diff --git a/tensorflow/core/framework/step_stats.proto b/tensorflow/core/framework/step_stats.proto
index 67cc9e3..f8cab13 100644
--- a/tensorflow/core/framework/step_stats.proto
+++ b/tensorflow/core/framework/step_stats.proto
@@ -77,6 +77,8 @@
message DeviceStepStats {
string device = 1;
repeated NodeExecStats node_stats = 2;
+ // Its key is thread id.
+ map<uint32, string> thread_names = 3;
}
message StepStats {
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index c169e86..993a898 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -20,14 +20,10 @@
namespace tensorflow {
-VariantTensorData::VariantTensorData() {}
-
VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
FromProto(std::move(proto));
}
-VariantTensorData::~VariantTensorData() {}
-
int VariantTensorData::tensors_size() const { return tensors_.size(); }
const Tensor& VariantTensorData::tensors(int index) const {
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index ca99e83..d98cf6b 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -37,11 +37,11 @@
// separate so that kernels do not need to depend on protos.
class VariantTensorData {
public:
- VariantTensorData();
+ VariantTensorData() = default;
+
// TODO(b/118823936): This silently returns if the proto is invalid.
// Consider calling FromProto explicitly instead.
VariantTensorData(VariantTensorDataProto proto);
- ~VariantTensorData();
// Name of the type of objects being serialized.
const string& type_name() const { return type_name_; }
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index 08d09de..8947f93 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -186,7 +186,7 @@
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(std::move(serialized));
+ y.Decode(serialized);
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 990b2fe..f36ca8c 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -96,7 +96,7 @@
// Restrict quantized ops to QUINT8 and QINT8 for now
if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
- return (T == DT_QUINT8 || T == DT_QINT8);
+ return (T == DT_QUINT8 || T == DT_QINT8 || T == DT_QINT32);
}
// Restrict regular ops to FLOAT
if (kernel.find(kMklOpLabelPattern) != string::npos) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 1a8ff34..e934978 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -389,7 +389,7 @@
CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.depthwise_conv2d,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d),
- CopyAttrsConv2DDepthwise, AlwaysRewrite});
+ CopyAttrsConv2DDepthwiseCheckConstFilter, AlwaysRewrite});
rinfo_.push_back(
{csinfo_.depthwise_conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.depthwise_conv2d_grad_input),
@@ -1495,6 +1495,8 @@
bool change_format = false);
static void CopyAttrsConv2DDepthwise(const Node* orig_node, NodeBuilder* nb,
bool change_format = false);
+ static void CopyAttrsConv2DDepthwiseCheckConstFilter(
+ const Node* orig_node, NodeBuilder* nb, bool change_format = false);
static void CopyAttrsConvCheckConstFilter(const Node* orig_node,
NodeBuilder* nb,
bool change_format = false);
@@ -2231,6 +2233,29 @@
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("strides", strides);
+ nb->Attr("dilations", dilations);
+ nb->Attr("padding", padding);
+ nb->Attr("data_format", data_format);
+}
+
+void MklLayoutRewritePass::CopyAttrsConv2DDepthwiseCheckConstFilter(
+ const Node* orig_node, NodeBuilder* nb, bool change_format) {
+ DataType T;
+ string data_format;
+ string padding;
+ std::vector<int32> strides;
+ std::vector<int32> dilations;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
+
Node* filter_node = nullptr;
orig_node->input_node(1, &filter_node);
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
index eb7ee8d..fdc6b79 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator_test.cc
@@ -16,7 +16,6 @@
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/cost_graph.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -102,7 +101,7 @@
Costs summary;
TF_ASSERT_OK(estimator.PredictCosts(item.graph, &run_metadata, &summary));
- EXPECT_EQ(Costs::NanoSeconds(9151), summary.execution_time);
+ EXPECT_EQ(Costs::NanoSeconds(9157), summary.execution_time);
// Note there are totally 17 nodes (RandomUniform creates 2 nodes), but
// grappler will not process "label", therefore we have 15 here instead
EXPECT_EQ(15, summary.num_ops_total);
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index d0ac87c9..8ec558b 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -1181,7 +1181,7 @@
// elements is larger than the given max size.
for (int i = 0; i < ic->num_outputs(); i++) {
const ShapeHandle& shape_handle = ic->output(i);
- if (!ic->FullyDefined(shape_handle) &&
+ if (!ic->FullyDefined(shape_handle) ||
ic->Value(ic->NumElements(shape_handle)) > max_size) {
return false;
}
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 0a7697a..fa6b05b 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -975,6 +975,52 @@
EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
}
+TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
+ // When using aggressive_shape_inference, we run EvaluateNode() for
+ // whitelisted ops and small input / output tensors. For instance, Fill op is
+ // evaluated and produces output tensor value if output tensor size is smal
+ // (currently, fewer than 17 elements); otherwise we don't run EvalauteNode().
+ // This is to avoid wasting time and memory for producing huge tensors (e.g.,
+ // initializing a large table using Fill.
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 4, {2}); // 4x4
+ Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Shape described by a is small; expect output values of Fill op.
+ Output c = ops::Fill(s.WithOpName("fill"), a, b);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(
+ /*assume_valid_feeds=*/false,
+ /*aggressive_shape_inference=*/true));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
+ EXPECT_TRUE(out_prop0.has_value());
+ }
+ {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1000, {4}); // 1000x1000x1000x1000
+ Output b = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Shape described by a is huge; in that case we skip value inference.
+ // Otherwise, it'd be too much overhead.
+ Output c = ops::Fill(s.WithOpName("fill"), a, b);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(
+ /*assume_valid_feeds=*/false,
+ /*aggressive_shape_inference=*/true));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
+ EXPECT_FALSE(out_prop0.has_value());
+ }
+}
+
TEST_F(GraphPropertiesTest, PackWithConstInput) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {});
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 59d20f1..1e2e160 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -27,7 +27,6 @@
namespace grappler {
constexpr int kOpsPerMac = 2;
-constexpr char kConst[] = "Const";
constexpr char kGuaranteeConst[] = "GuaranteeConst";
constexpr char kConv2d[] = "Conv2D";
constexpr char kConv2dBackpropFilter[] = "Conv2DBackpropFilter";
@@ -50,8 +49,6 @@
constexpr char kRecv[] = "_Recv";
constexpr char kSend[] = "_Send";
constexpr char kBatchMatMul[] = "BatchMatMul";
-constexpr char kVariable[] = "Variable";
-constexpr char kVariableV2[] = "VariableV2";
constexpr char kRank[] = "Rank";
constexpr char kShape[] = "Shape";
constexpr char kShapeN[] = "ShapeN";
@@ -68,6 +65,13 @@
constexpr char kFusedBatchNorm[] = "FusedBatchNorm";
constexpr char kFusedBatchNormGrad[] = "FusedBatchNormGrad";
constexpr char kQuantizedMatMulV2[] = "QuantizedMatMulV2";
+// Persistent ops.
+constexpr char kConst[] = "Const";
+constexpr char kVariable[] = "Variable";
+constexpr char kVariableV2[] = "VariableV2";
+constexpr char kAutoReloadVariable[] = "AutoReloadVariable";
+constexpr char kVarHandleOp[] = "VarHandleOp";
+constexpr char kReadVariableOp[] = "ReadVariableOp";
static const Costs::Duration kMinComputeTime(1);
@@ -259,10 +263,6 @@
{kRecv, wrap(&OpLevelCostEstimator::PredictIdentity)},
{kSend, wrap(&OpLevelCostEstimator::PredictIdentity)},
- {kConst, wrap(&OpLevelCostEstimator::PredictVariable)},
- {kVariable, wrap(&OpLevelCostEstimator::PredictVariable)},
- {kVariableV2, wrap(&OpLevelCostEstimator::PredictVariable)},
-
{kRank, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShape, wrap(&OpLevelCostEstimator::PredictMetadata)},
{kShapeN, wrap(&OpLevelCostEstimator::PredictMetadata)},
@@ -276,6 +276,11 @@
wrap(&OpLevelCostEstimator::PredictFusedBatchNormGrad)},
};
+ persistent_ops_ = {
+ kConst, kVariable, kVariableV2, kAutoReloadVariable,
+ kVarHandleOp, kReadVariableOp,
+ };
+
#define EIGEN_COST(X) Eigen::internal::functor_traits<Eigen::internal::X>::Cost
// Quantize = apply min and max bounds, multiply by scale factor and round.
@@ -363,21 +368,25 @@
Costs OpLevelCostEstimator::PredictCosts(const OpContext& op_context) const {
const auto& op_info = op_context.op_info;
auto it = device_cost_impl_.find(op_info.op());
- if (it == device_cost_impl_.end()) {
- if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
- return PredictCwiseOp(op_context);
- }
-
- VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
-
- return PredictCostOfAnUnknownOp(op_context);
+ if (it != device_cost_impl_.end()) {
+ std::function<Costs(const OpContext&)> estimator = it->second;
+ Costs costs = estimator(op_context);
+ VLOG(1) << "Operation " << op_info.op() << " takes "
+ << costs.execution_time.count() << " ns.";
+ return costs;
}
- std::function<Costs(const OpContext&)> estimator = it->second;
- Costs costs = estimator(op_context);
- VLOG(1) << "Operation " << op_info.op() << " takes "
- << costs.execution_time.count() << " ns.";
- return costs;
+ if (persistent_ops_.find(op_info.op()) != persistent_ops_.end()) {
+ return PredictVariable(op_context);
+ }
+
+ if (elementwise_ops_.find(op_info.op()) != elementwise_ops_.end()) {
+ return PredictCwiseOp(op_context);
+ }
+
+ VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
+
+ return PredictCostOfAnUnknownOp(op_context);
}
DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
@@ -1240,7 +1249,7 @@
result.num_ops_with_unknown_shapes = result.inaccurate;
result.compute_time = kMinComputeTime;
- result.execution_time = result.execution_time;
+ result.execution_time = result.compute_time;
return result;
}
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.h b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
index f8ba8c6..ace8fb2 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.h
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.h
@@ -193,6 +193,7 @@
// If true, assume compute and memory overlap; hence, the op cost is max of
// compute_time and memory_time, insteaf of sum of those two.
bool compute_memory_overlap_;
+ std::set<string> persistent_ops_;
private:
friend class OpLevelCostEstimatorTest;
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index aa0fc9d..04c6ada 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -499,6 +499,26 @@
OpLevelCostEstimator estimator_;
};
+TEST_F(OpLevelCostEstimatorTest, TestPersistentOpCosts) {
+ OpContext op_context;
+ SetCpuDevice(&op_context.op_info);
+ std::unordered_set<string> persisent_ops = {
+ "Const", "Variable", "VariableV2", "AutoReloadVariable",
+ "VarHandleOp", "ReadVariableOp",
+ };
+ // Minmum cost for all persistent ops.
+ for (const auto& op : persisent_ops) {
+ op_context.op_info.set_op(op);
+ auto cost = estimator_.PredictCosts(op_context);
+ EXPECT_EQ(Costs::Duration(0), cost.memory_time);
+ EXPECT_EQ(Costs::Duration(1), cost.compute_time);
+ EXPECT_EQ(Costs::Duration(1), cost.execution_time);
+ EXPECT_EQ(1, cost.num_ops_total);
+ EXPECT_FALSE(cost.inaccurate);
+ EXPECT_EQ(0, cost.num_ops_with_unknown_shapes);
+ }
+}
+
TEST_F(OpLevelCostEstimatorTest, TestGatherCosts) {
OpContext op_context;
SetCpuDevice(&op_context.op_info);
diff --git a/tensorflow/core/grappler/costs/virtual_placer.cc b/tensorflow/core/grappler/costs/virtual_placer.cc
index 8f5f16e..146eecf 100644
--- a/tensorflow/core/grappler/costs/virtual_placer.cc
+++ b/tensorflow/core/grappler/costs/virtual_placer.cc
@@ -87,6 +87,7 @@
default_device_name_ = devices_.begin()->first; // Any device.
}
}
+ VLOG(3) << "default device name: " << default_device_name_;
// Scan the device names from the cluster, and if there is one job name used,
// use it for canonical device name.
@@ -102,14 +103,15 @@
}
}
}
- // If there is only type of job name in all the devices in the cluster, use
- // that one as default job name; otherwise, use localhost.
+ // If there is only one type of job name in all the devices in the cluster,
+ // use that one as default job name; otherwise, use localhost.
// TODO(dyoon): this should be improved, especially when the cluster is
// composed of multiple worker, PS, and other types of jobs.
if (job_names_from_cluster.size() == 1) {
auto it = job_names_from_cluster.begin();
default_job_name_lowercase_ = *it;
}
+ VLOG(3) << "default job name: " << default_job_name_lowercase_;
}
const DeviceProperties& VirtualPlacer::get_device(const NodeDef& node) const {
diff --git a/tensorflow/core/grappler/costs/virtual_placer.h b/tensorflow/core/grappler/costs/virtual_placer.h
index fee5ce0..e17ece7 100644
--- a/tensorflow/core/grappler/costs/virtual_placer.h
+++ b/tensorflow/core/grappler/costs/virtual_placer.h
@@ -16,7 +16,6 @@
#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
-#include <unordered_map>
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/device_properties.pb.h"
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index d4bd2cc..bc95c9c 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -117,15 +117,10 @@
// Tensorflow functions do not prune stateful or dataset-output ops from
// the function body (see PruneFunctionBody in common_runtime/function.cc).
- //
- // We also keep placeholders in the functions body, because it's a bug to have
- // placeholders inside functions, and we want to catch such invalid graphs
- // early.
- if (optimization_options_.is_function_instantiation) {
+ if (!optimization_options_.allow_pruning_stateful_and_dataset_ops) {
FunctionLibraryDefinition fn_library(OpRegistry::Global(), graph.library());
for (const NodeDef& node : graph.node()) {
- if (IsStateful(node, &fn_library) || IsDataset(node) ||
- IsPlaceholder(node)) {
+ if (IsStateful(node, &fn_library) || IsDataset(node)) {
result.insert(node.name());
}
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 75712e9..57949b3 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -91,7 +91,13 @@
// by running Grappler optimizer passes. One main difference is that
// functions do not prune ops with side-effects and dataset-output ops (see
// PruneFunctionBody in common_runtime/function.cc).
- bool is_function_instantiation = false;
+ bool allow_pruning_stateful_and_dataset_ops = true;
+
+ // If true Grappler will optimize the main graph, and also all functions in
+ // the graph function library (function can't be polymorphic, it can't have
+ // undefined type parameters in the function signature, or placeholder
+ // attributes in the function body).
+ bool optimize_function_library = true;
};
const std::unordered_set<string>& devices() const;
diff --git a/tensorflow/core/grappler/mutable_graph_view.cc b/tensorflow/core/grappler/mutable_graph_view.cc
index e564310..64098e7 100644
--- a/tensorflow/core/grappler/mutable_graph_view.cc
+++ b/tensorflow/core/grappler/mutable_graph_view.cc
@@ -94,6 +94,18 @@
return CanDedupControlWithRegularInput(graph, *control_node);
}
+bool HasRegularFaninNode(const MutableGraphView& graph, const NodeDef& node,
+ absl::string_view fanin_node_name) {
+ const int num_regular_fanins =
+ graph.NumFanins(node, /*include_controlling_nodes=*/false);
+ for (int i = 0; i < num_regular_fanins; ++i) {
+ if (ParseTensorName(node.input(i)).node() == fanin_node_name) {
+ return true;
+ }
+ }
+ return false;
+}
+
Status MutationError(absl::string_view function_name, absl::string_view params,
absl::string_view msg) {
return errors::InvalidArgument(absl::Substitute(
@@ -168,6 +180,13 @@
return Status::OK();
}
+string GeneratedNameForIdentityConsumingSwitch(
+ const MutableGraphView::OutputPort& fanin) {
+ return AddPrefixToNodeName(
+ absl::StrCat(fanin.node->name(), "_", fanin.port_id),
+ kMutableGraphViewCtrl);
+}
+
} // namespace
void MutableGraphView::AddAndDedupFanouts(NodeDef* node) {
@@ -325,6 +344,60 @@
return Status::OK();
}
+Status MutableGraphView::UpdateNode(
+ absl::string_view node_name, absl::string_view op, absl::string_view device,
+ absl::Span<const std::pair<string, AttrValue>> attrs) {
+ auto error_status = [node_name, op, device, attrs](absl::string_view msg) {
+ std::vector<string> attr_strs;
+ attr_strs.reserve(attrs.size());
+ for (const auto& attr : attrs) {
+ string attr_str = absl::Substitute("('$0', $1)", attr.first,
+ attr.second.ShortDebugString());
+ attr_strs.push_back(attr_str);
+ }
+ string params =
+ absl::Substitute("node_name='$0', op='$1', device='$2', attrs={$3}",
+ node_name, op, device, absl::StrJoin(attr_strs, ", "));
+ return MutationError("UpdateNodeOp", params, msg);
+ };
+
+ NodeDef* node = GetNode(node_name);
+ TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
+
+ MutableGraphView::OutputPort control_port(node, Graph::kControlSlot);
+ auto control_fanouts = GetFanout(control_port);
+ if (op == "Switch" && !control_fanouts.empty()) {
+ return error_status(
+ "can't change node op to Switch when node drives a control dependency "
+ "(alternatively, we could add the identity node needed, but it seems "
+ "like an unlikely event and probably a mistake)");
+ }
+
+ if (node->device() != device) {
+ node->set_device(string(device));
+ }
+ node->mutable_attr()->clear();
+ for (const auto& attr : attrs) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+
+ if (node->op() == op) {
+ return Status::OK();
+ }
+
+ node->set_op(string(op));
+
+ if (CanDedupControlWithRegularInput(*this, *node)) {
+ for (const auto& control_fanout : control_fanouts) {
+ if (HasRegularFaninNode(*this, *control_fanout.node, node->name())) {
+ RemoveControllingFaninInternal(control_fanout.node, node);
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
Status MutableGraphView::UpdateFanouts(absl::string_view from_node_name,
absl::string_view to_node_name) {
NodeDef* from_node = GetNode(from_node_name);
@@ -546,6 +619,68 @@
return Status::OK();
}
+NodeDef* MutableGraphView::GetControllingFaninToAdd(absl::string_view node_name,
+ const OutputPort& fanin,
+ string* error_msg) {
+ if (!IsSwitch(*fanin.node)) {
+ return fanin.node;
+ } else {
+ TensorId tensor_id(fanin.node->name(), fanin.port_id);
+ if (IsOutputPortControlling(fanin)) {
+ // Can't add a Switch node control dependency.
+ *error_msg = absl::Substitute(
+ "can't add fanin '$0' as it will become a Switch control dependency",
+ tensor_id.ToString());
+ return nullptr;
+ }
+ // We can't anchor control dependencies directly on the switch node: unlike
+ // other nodes only one of the outputs of the switch node will be generated
+ // when the switch node is executed, and we need to make sure the control
+ // dependency is only triggered when the corresponding output is triggered.
+ // We start by looking for an identity node connected to the output of the
+ // switch node, and use it to anchor the control dependency.
+ auto fanouts = GetFanouts(*fanin.node, /*include_controlled_nodes=*/false);
+ for (auto fanout : fanouts) {
+ if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
+ if (ParseTensorName(fanout.node->input(0)) == tensor_id) {
+ if (fanout.node->name() == node_name) {
+ *error_msg =
+ absl::Substitute("can't add found fanin '$0' to self",
+ AsControlDependency(fanout.node->name()));
+ return nullptr;
+ }
+ return fanout.node;
+ }
+ }
+ }
+
+ // No node found, check if node to be created is itself.
+ if (GeneratedNameForIdentityConsumingSwitch(fanin) == node_name) {
+ *error_msg = absl::Substitute("can't add generated fanin '$0' to self",
+ AsControlDependency(string(node_name)));
+ }
+ }
+ return nullptr;
+}
+
+NodeDef* MutableGraphView::GetOrCreateIdentityConsumingSwitch(
+ const OutputPort& fanin) {
+ // We haven't found an existing node where we can anchor the control
+ // dependency: add a new identity node.
+ string identity_name = GeneratedNameForIdentityConsumingSwitch(fanin);
+ NodeDef* identity_node = GetNode(identity_name);
+ if (identity_node == nullptr) {
+ NodeDef new_node;
+ new_node.set_name(identity_name);
+ new_node.set_op("Identity");
+ new_node.set_device(fanin.node->device());
+ (*new_node.mutable_attr())["T"].set_type(fanin.node->attr().at("T").type());
+ new_node.add_input(TensorIdToString({fanin.node->name(), fanin.port_id}));
+ identity_node = AddNode(std::move(new_node));
+ }
+ return identity_node;
+}
+
Status MutableGraphView::AddControllingFanin(absl::string_view node_name,
const TensorId& fanin) {
auto error_status = [node_name, fanin](absl::string_view msg) {
@@ -561,59 +696,19 @@
NodeDef* fanin_node = GetNode(fanin.node());
TF_RETURN_IF_ERROR(CheckNodeExists(fanin.node(), fanin_node, error_status));
- if (!IsSwitch(*fanin_node)) {
- AddFaninInternal(node, {fanin_node, Graph::kControlSlot});
- } else {
- if (IsTensorIdControlling(fanin)) {
- // Can't add a Switch node control dependency.
- return error_status(absl::Substitute(
- "can't add fanin '$0' as it will become a Switch control dependency",
- fanin.ToString()));
- }
- // We can't anchor control dependencies directly on the switch node: unlike
- // other nodes only one of the outputs of the switch node will be generated
- // when the switch node is executed, and we need to make sure the control
- // dependency is only triggered when the corresponding output is triggered.
- // We start by looking for an identity node connected to the output of the
- // switch node, and use it to anchor the control dependency.
- auto fanouts = GetFanouts(*fanin_node, /*include_controlled_nodes=*/false);
- for (auto fanout : fanouts) {
- if (IsIdentity(*fanout.node) || IsIdentityNSingleInput(*fanout.node)) {
- if (ParseTensorName(fanout.node->input(0)) == fanin) {
- if (fanout.node->name() == node_name) {
- return error_status(
- absl::Substitute("can't add found fanin '$0' to self",
- AsControlDependency(fanout.node->name())));
- }
- AddFaninInternal(node, {fanout.node, Graph::kControlSlot});
- return Status::OK();
- }
- }
- }
- // We haven't found an existing node where we can anchor the control
- // dependency: add a new identity node.
- string ctrl_dep_name = AddPrefixToNodeName(
- absl::StrCat(fanin.node(), "_", fanin.index()), kMutableGraphViewCtrl);
- if (node_name == ctrl_dep_name) {
- return error_status(
- absl::Substitute("can't add generated fanin '$0' to self",
- AsControlDependency(ctrl_dep_name)));
- }
+ OutputPort fanin_port(fanin_node, fanin.index());
- // Reuse a previously created node, if possible.
- NodeDef* ctrl_dep_node = GetNode(ctrl_dep_name);
- if (ctrl_dep_node == nullptr) {
- NodeDef new_node;
- new_node.set_name(ctrl_dep_name);
- new_node.set_op("Identity");
- new_node.set_device(fanin_node->device());
- (*new_node.mutable_attr())["T"].set_type(
- fanin_node->attr().at("T").type());
- new_node.add_input(TensorIdToString(fanin));
- ctrl_dep_node = AddNode(std::move(new_node));
- }
- AddFaninInternal(node, {ctrl_dep_node, Graph::kControlSlot});
+ string error_msg = "";
+ NodeDef* control_node = GetControllingFaninToAdd(
+ node_name, {fanin_node, fanin.index()}, &error_msg);
+ if (!error_msg.empty()) {
+ return error_status(error_msg);
}
+ if (control_node == nullptr) {
+ control_node = GetOrCreateIdentityConsumingSwitch(fanin_port);
+ }
+ AddFaninInternal(node, {control_node, Graph::kControlSlot});
+
return Status::OK();
}
@@ -990,6 +1085,77 @@
return Status::OK();
}
+Status MutableGraphView::UpdateAllRegularFaninsToControlling(
+ absl::string_view node_name) {
+ auto error_status = [node_name](absl::string_view msg) {
+ string params = absl::Substitute("node_name='$0'", node_name);
+ return MutationError("UpdateAllRegularFaninsToControlling", params, msg);
+ };
+
+ NodeDef* node = GetNode(node_name);
+ TF_RETURN_IF_ERROR(CheckNodeExists(node_name, node, error_status));
+
+ const int num_regular_fanins =
+ NumFanins(*node, /*include_controlling_nodes=*/false);
+ std::vector<OutputPort> regular_fanins;
+ regular_fanins.reserve(num_regular_fanins);
+ std::vector<NodeDef*> controlling_fanins;
+ controlling_fanins.reserve(num_regular_fanins);
+
+ // Get all regular fanins and derive controlling fanins.
+ for (int i = 0; i < num_regular_fanins; ++i) {
+ TensorId tensor_id = ParseTensorName(node->input(i));
+ OutputPort fanin_port(nodes()[tensor_id.node()], tensor_id.index());
+
+ string error_msg = "";
+ NodeDef* control_node =
+ GetControllingFaninToAdd(node_name, fanin_port, &error_msg);
+ if (!error_msg.empty()) {
+ return error_status(error_msg);
+ }
+
+ regular_fanins.push_back(fanin_port);
+ controlling_fanins.push_back(control_node);
+ }
+
+ // Replace regular fanins with controlling fanins and dedup.
+ int pos = 0;
+ InputPort input_port(node, Graph::kControlSlot);
+ absl::flat_hash_set<absl::string_view> controls;
+ for (int i = 0; i < num_regular_fanins; ++i) {
+ OutputPort fanin_port = regular_fanins[i];
+ NodeDef* control = controlling_fanins[i];
+ if (control == nullptr) {
+ control = GetOrCreateIdentityConsumingSwitch(fanin_port);
+ }
+ fanouts()[fanin_port].erase({node, i});
+ if (controls.contains(control->name())) {
+ continue;
+ }
+ controls.insert(control->name());
+ node->set_input(pos, AsControlDependency(control->name()));
+ fanouts()[{control, Graph::kControlSlot}].insert(input_port);
+ ++pos;
+ }
+
+ // Shift existing controlling fanins and dedup.
+ for (int i = num_regular_fanins; i < node->input_size(); ++i) {
+ TensorId tensor_id = ParseTensorName(node->input(i));
+ if (controls.contains(tensor_id.node())) {
+ continue;
+ }
+ controls.insert(tensor_id.node());
+ node->mutable_input()->SwapElements(pos, i);
+ ++pos;
+ }
+
+ // Remove duplicate controls and leftover regular fanins.
+ node->mutable_input()->DeleteSubrange(pos, node->input_size() - pos);
+ max_regular_input_port().erase(node);
+
+ return Status::OK();
+}
+
Status MutableGraphView::CheckNodesCanBeDeleted(
const absl::flat_hash_set<string>& nodes_to_delete) {
std::vector<string> missing_nodes;
diff --git a/tensorflow/core/grappler/mutable_graph_view.h b/tensorflow/core/grappler/mutable_graph_view.h
index 16ef832..08a1cd1 100644
--- a/tensorflow/core/grappler/mutable_graph_view.h
+++ b/tensorflow/core/grappler/mutable_graph_view.h
@@ -76,6 +76,14 @@
// underlying graph, which leaves subgraph in valid but undefined state.
Status AddSubgraph(GraphDef&& subgraph);
+ // Updates node `node_name` op, device, and attributes. This will clear any
+ // existing attributes. If it is not possible to update the node or if the
+ // node does not exist, an error will be returned and nothing will be modified
+ // in the graph.
+ Status UpdateNode(absl::string_view node_name, absl::string_view op,
+ absl::string_view device,
+ absl::Span<const std::pair<string, AttrValue>> attrs);
+
// Updates all fanouts (input ports fetching output tensors) from
// `from_node_name` to the `to_node_name`, including control dependencies.
//
@@ -196,6 +204,11 @@
Status SwapRegularFaninsByPorts(absl::string_view node_name, int from_port,
int to_port);
+ // Updates all regular fanins to equivalent controlling fanins. If it is not
+ // possible, an error will be returned and nothing will be modified in the
+ // graph.
+ Status UpdateAllRegularFaninsToControlling(absl::string_view node_name);
+
// Deletes nodes from the graph. If a node can't be safely removed,
// specifically if a node still has fanouts, an error will be returned. Nodes
// that can't be found are ignored.
@@ -241,6 +254,21 @@
// added after existing non control dependency inputs.
bool AddFaninInternal(NodeDef* node, const OutputPort& fanin);
+ // Finds control dependency node to be used based on fanin. If fanin is not a
+ // Switch node, fanin.node is simply returned. Otherwise this will try to find
+ // a candidate Identity node consuming fanin, as the control dependency. If it
+ // is not possible or will introduce a self loop, an error message will be
+ // set. If nullptr is returned with no error
+ // GetOrCreateIdentityConsumingSwitch should be called to generate the new
+ // Identity node.
+ NodeDef* GetControllingFaninToAdd(absl::string_view node_name,
+ const OutputPort& fanin, string* error_msg);
+
+ // Finds a generated Identity node consuming Switch node `fanin.node` at port
+ // `fanin.port_id`. If such a node does not exist, a new Identity node will be
+ // created.
+ NodeDef* GetOrCreateIdentityConsumingSwitch(const OutputPort& fanin);
+
// Removes all instances of regular fanin `fanin` from node `node`.
bool RemoveRegularFaninInternal(NodeDef* node, const OutputPort& fanin);
diff --git a/tensorflow/core/grappler/mutable_graph_view_test.cc b/tensorflow/core/grappler/mutable_graph_view_test.cc
index f4ee4d8..06333d3 100644
--- a/tensorflow/core/grappler/mutable_graph_view_test.cc
+++ b/tensorflow/core/grappler/mutable_graph_view_test.cc
@@ -209,6 +209,102 @@
"different function definition with the same name: XTimesTwo.");
}
+TEST(MutableGraphViewTest, UpdateNodeNoDedupControlDependency) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def = test::function::GDef(
+ {NDef("bar_1", "Switch", {}, {}), NDef("bar_2", "Identity", {"bar_1:1"}),
+ NDef("other", "NotImportant", {}, {}),
+ NDef("foo_1", "NotImportant", {"bar_2", "other", "bar_2:1", "^bar_2"}),
+ NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ AttrValue list_value;
+ list_value.mutable_list()->add_type(DT_FLOAT);
+ TF_EXPECT_OK(
+ graph.UpdateNode("bar_2", "IdentityN", kDevice, {{"T", list_value}}));
+
+ CheckNode(graph, "bar_1", "Switch", "", {}, {}, {"bar_2"});
+ CheckNode(graph, "bar_2", "IdentityN", kDevice, {{"T", list_value}},
+ {"bar_1:1"}, {"foo_1", "foo_1:2", "^foo_1", "foo_2:1", "^foo_2"});
+ CheckNode(graph, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
+ CheckNode(graph, "foo_1", "NotImportant", "", {},
+ {"bar_2", "other", "bar_2:1", "^bar_2"}, {});
+ CheckNode(graph, "foo_2", "NotImportant", "", {},
+ {"other:1", "bar_2:2", "^bar_2"}, {});
+
+ CheckGraph(graph);
+}
+
+TEST(MutableGraphViewTest, UpdateNodeDedupControlDependency) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def = test::function::GDef(
+ {NDef("bar_1", "Switch", {}, {}), NDef("bar_2", "Identity", {"bar_1:1"}),
+ NDef("other", "NotImportant", {}, {}),
+ NDef("foo_1", "NotImportant", {"bar_2", "other", "bar_2:1", "^bar_2"}),
+ NDef("foo_2", "NotImportant", {"other:1", "bar_2:2", "^bar_2"})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ TF_EXPECT_OK(graph.UpdateNode("bar_2", "NotImportant", kDevice, {}));
+
+ CheckNode(graph, "bar_1", "Switch", "", {}, {}, {"bar_2"});
+ CheckNode(graph, "bar_2", "NotImportant", kDevice, {}, {"bar_1:1"},
+ {"foo_1", "foo_1:2", "foo_2:1"});
+ CheckNode(graph, "other", "NotImportant", "", {}, {}, {"foo_1:1", "foo_2"});
+ CheckNode(graph, "foo_1", "NotImportant", "", {},
+ {"bar_2", "other", "bar_2:1"}, {});
+ CheckNode(graph, "foo_2", "NotImportant", "", {}, {"other:1", "bar_2:2"}, {});
+
+ CheckGraph(graph);
+}
+
+TEST(MutableGraphViewTest, UpdateNodeSwitchNoControlDependency) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def =
+ test::function::GDef({NDef("foo", "NotImportant", {}, {}),
+ NDef("bar", "NotImportant", {"foo:1"})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ TF_EXPECT_OK(graph.UpdateNode("foo", "Switch", kDevice, {}));
+
+ CheckNode(graph, "foo", "Switch", kDevice, {}, {}, {"bar"});
+ CheckNode(graph, "bar", "NotImportant", "", {}, {"foo:1"}, {});
+
+ CheckGraph(graph);
+}
+
+TEST(MutableGraphViewTest, UpdateNodeSwitchControlDependency) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def =
+ test::function::GDef({NDef("foo", "NotImportant", {}, {}),
+ NDef("bar", "NotImportant", {"^foo"})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ AttrValue attr;
+ attr.set_type(DT_FLOAT);
+ Status s = graph.UpdateNode("foo", "Switch", kDevice, {{"T", attr}});
+ EXPECT_FALSE(s.ok());
+ string expected_msg =
+ "MutableGraphView::UpdateNodeOp(node_name='foo', op='Switch', "
+ "device='/device:foo:0', attrs={('T', type: DT_FLOAT)}) error: can't "
+ "change node op to Switch when node drives a control dependency "
+ "(alternatively, we could add the identity node needed, but it seems "
+ "like an unlikely event and probably a mistake).";
+ EXPECT_EQ(s.error_message(), expected_msg);
+
+ CheckNode(graph, "foo", "NotImportant", "", {}, {}, {"^bar"});
+ CheckNode(graph, "bar", "NotImportant", "", {}, {"^foo"}, {});
+
+ CheckGraph(graph);
+}
+
TEST(MutableGraphViewTest, AddAndUpdateFanouts) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
@@ -2216,6 +2312,109 @@
CheckGraph(graph);
}
+void TestUpdateAllRegularFaninsToControlling(
+ absl::string_view node_name, bool node_exists, bool success,
+ const string& error_msg, absl::Span<const string> expected_fanins) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def = test::function::GDef(
+ {NDef("a", "NotImportant", {}, {}),
+ NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("b", "NotImportant", {"switch:1"}, {}),
+ NDef("ConstantFoldingCtrl/switch_1", "Identity", {"switch:1"},
+ {{"T", DT_FLOAT}}, kDevice),
+ NDef("c", "NotImportant", {"a", "^b"}, {}),
+ NDef("d", "NotImportant", {"b", "c"}, {}),
+ NDef("e", "NotImportant", {"^d"}, {})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ NodeDef* node = graph.GetNode(node_name);
+ if (node_exists) {
+ EXPECT_NE(node, nullptr);
+ } else {
+ EXPECT_EQ(node, nullptr);
+ }
+
+ absl::flat_hash_map<string, std::vector<string>> unmodified_node_inputs =
+ GetNodeInputsFromGraph(graph_def, node_name);
+
+ Status s = graph.UpdateAllRegularFaninsToControlling(node_name);
+ EXPECT_EQ(s.ok(), success);
+ if (!success) {
+ EXPECT_EQ(s.error_message(), error_msg);
+ }
+ if (node_exists) {
+ CompareNodeFanins(graph, node, expected_fanins);
+ }
+
+ CheckUnmodifiedNodeFanins(graph_def, node_name, unmodified_node_inputs);
+
+ CheckGraph(graph);
+}
+
+TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControlling) {
+ string error_msg;
+ // Nodes with some regular fanins and some controls.
+ TestUpdateAllRegularFaninsToControlling("a", /*node_exists=*/true,
+ /*success=*/true, error_msg, {});
+ TestUpdateAllRegularFaninsToControlling("c", /*node_exists=*/true,
+ /*success=*/true, error_msg,
+ {"^a", "^b"});
+ TestUpdateAllRegularFaninsToControlling("d", /*node_exists=*/true,
+ /*success=*/true, error_msg,
+ {"^b", "^c"});
+ TestUpdateAllRegularFaninsToControlling("e", /*node_exists=*/true,
+ /*success=*/true, error_msg, {"^d"});
+
+ // Use existing Identity to pin control dependency of Switch.
+ TestUpdateAllRegularFaninsToControlling("b", /*node_exists=*/true,
+ /*success=*/true, error_msg,
+ {"^ConstantFoldingCtrl/switch_1"});
+
+ // Missing node.
+ error_msg =
+ "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='f') "
+ "error: node 'f' was not found.";
+ TestUpdateAllRegularFaninsToControlling("f", /*node_exists=*/false,
+ /*success=*/false, error_msg, {});
+
+ // Error in getting controlling fanin.
+ error_msg =
+ "MutableGraphView::UpdateAllRegularFaninsToControlling(node_name='"
+ "ConstantFoldingCtrl/switch_1') error: can't add found fanin "
+ "'^ConstantFoldingCtrl/switch_1' to self.";
+ TestUpdateAllRegularFaninsToControlling("ConstantFoldingCtrl/switch_1",
+ /*node_exists=*/true,
+ /*success=*/false, error_msg,
+ {"switch:1"});
+}
+
+TEST(MutableGraphViewTest, UpdateAllRegularFaninsToControllingConsumingSwitch) {
+ constexpr char kDevice[] = "/device:foo:0";
+ GraphDef graph_def = test::function::GDef(
+ {NDef("a", "NotImportant", {}, {}),
+ NDef("switch", "Switch", {}, {{"T", DT_FLOAT}}, kDevice),
+ NDef("b", "NotImportant", {"switch:1"}, {})},
+ /*funcs=*/{});
+
+ MutableGraphView graph(&graph_def);
+
+ TF_EXPECT_OK(graph.UpdateAllRegularFaninsToControlling("b"));
+
+ EXPECT_EQ(graph.graph()->node_size(), 4);
+
+ CheckNode(graph, "a", "NotImportant", "", {}, {}, {});
+ CheckNode(graph, "switch", "Switch", kDevice, {{"T", DT_FLOAT}}, {},
+ {"ConstantFoldingCtrl/switch_1"});
+ CheckNode(graph, "b", "NotImportant", "", {},
+ {"^ConstantFoldingCtrl/switch_1"}, {});
+ CheckNode(graph, "ConstantFoldingCtrl/switch_1", "Identity", kDevice,
+ {{"T", DT_FLOAT}}, {"switch:1"}, {"^b"});
+
+ CheckGraph(graph);
+}
+
TEST(MutableGraphViewTest, DeleteNodes) {
// Actual node.op() is not important in this test.
GraphDef graph_def = test::function::GDef(
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 27fba4f..f403e16 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -47,6 +47,12 @@
node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}
+bool IsAnyMaxPool(const NodeDef& node) {
+ const auto& op = node.op();
+ return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
+ op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
+}
+
bool IsApproximateEqual(const NodeDef& node) {
return node.op() == "ApproximateEqual";
}
@@ -164,18 +170,14 @@
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
CHECK_NOTNULL((new gtl::FlatSet<string>{
- "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
- "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
- "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil",
+ "Elu", "Erf", "Exp", "Expm1", "Floor", "Log",
+ "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid",
+ "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh",
}));
static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
- CHECK_NOTNULL((new gtl::FlatSet<string>{
- "Inv",
- "Reciprocal",
- "Erfc",
- "Rsqrt",
- "Neg",
- }));
+ CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Inv", "Neg",
+ "Reciprocal", "Rsqrt"}));
if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = true;
@@ -320,6 +322,8 @@
return op == "NextIteration" || op == "RefNextIteration";
}
+bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
+
bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
bool IsPad(const NodeDef& node) {
@@ -341,7 +345,9 @@
bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
-bool IsPrint(const NodeDef& node) { return node.op() == "Print"; }
+bool IsPrint(const NodeDef& node) {
+ return node.op() == "Print" || node.op() == "PrintV2";
+}
bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
@@ -538,6 +544,8 @@
return op == "While" || op == "StatelessWhile";
}
+bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
+
bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
namespace {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index a1ee253..bc1d8c1 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -28,6 +28,7 @@
bool IsAngle(const NodeDef& node);
bool IsAny(const NodeDef& node);
bool IsAnyDiv(const NodeDef& node);
+bool IsAnyMaxPool(const NodeDef& node);
bool IsApproximateEqual(const NodeDef& node);
bool IsAvgPoolGrad(const NodeDef& node);
bool IsAssert(const NodeDef& node);
@@ -100,6 +101,7 @@
bool IsMul(const NodeDef& node);
bool IsMatMul(const NodeDef& node);
bool IsNextIteration(const NodeDef& node);
+bool IsOnesLike(const NodeDef& node);
bool IsPack(const NodeDef& node);
bool IsPad(const NodeDef& node);
bool IsPack(const NodeDef& node);
@@ -170,6 +172,7 @@
bool IsUnpack(const NodeDef& node);
bool IsVariable(const NodeDef& node);
bool IsWhile(const NodeDef& node);
+bool IsZerosLike(const NodeDef& node);
bool IsZeta(const NodeDef& node);
// Return true if the op is an aggregation (e.g. Add, AddN).
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index cb5d7d6..cbf0d68 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -103,6 +103,7 @@
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:symbolic_shapes",
"@com_google_absl//absl/container:flat_hash_set",
+ "@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 2168dbd..902cb3f 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -2721,7 +2721,7 @@
~OptimizeMaxOrMinOfMonotonicStage() override = default;
bool IsSupported(const NodeDef* node) const override {
- return IsMax(*node) || IsMin(*node);
+ return IsMax(*node) || IsMin(*node) || IsAnyMaxPool(*node);
}
Status TrySimplify(NodeDef* reduction_node,
@@ -2735,10 +2735,13 @@
// 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
+ // 3. is monotonic increasing if reduction_node is a pooling operation
+ // since we don't have MinPool operations.
bool is_non_decreasing = false;
if (!IsInPreserveSet(*inner_function) &&
IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
- ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1 &&
+ (is_non_decreasing || !IsAnyMaxPool(*reduction_node))) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
@@ -3239,13 +3242,17 @@
}
private:
- uint64 ComputeSignature(const NodeDef& node) const;
+ uint64 ComputeSignature(const NodeDef& node);
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
- std::unordered_map<uint64, std::vector<NodeDef*>> rep_;
+ absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
+ absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
};
-uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
+uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
+ auto it = memoized_signatures_.find(&node);
+ if (it != memoized_signatures_.end()) return it->second;
+
uint64 h = Hash64(node.op());
h = Hash64Combine(Hash64(node.device()), h);
@@ -3259,6 +3266,7 @@
h = Hash64CombineUnordered(Hash64(attr.first), h);
h = Hash64CombineUnordered(FastAttrValueHash(attr.second), h);
}
+ memoized_signatures_.emplace(&node, h);
return h;
}
@@ -3279,31 +3287,29 @@
// Compare inputs.
if (IsCommutative(node1)) {
std::vector<string> inputs1(node1.input().begin(), node1.input().end());
- std::vector<string> inputs2(node2.input().begin(), node2.input().end());
std::sort(inputs1.begin(), inputs1.end());
+ std::vector<string> inputs2(node2.input().begin(), node2.input().end());
std::sort(inputs2.begin(), inputs2.end());
return inputs1 == inputs2;
} else {
- std::vector<string> regular_inputs1;
- std::vector<string> regular_inputs2;
- std::vector<string> ctrl_inputs1;
- std::vector<string> ctrl_inputs2;
- for (int index = 0; index < node1.input_size(); ++index) {
+ // The order or ordinary inputs matters.
+ int index = 0;
+ for (; index < node1.input_size(); ++index) {
if (IsControlInput(node1.input(index))) {
- ctrl_inputs1.push_back(node1.input(index));
- ctrl_inputs2.push_back(node2.input(index));
- } else {
- regular_inputs1.push_back(node1.input(index));
- regular_inputs2.push_back(node2.input(index));
+ break;
+ } else if (node1.input(index) != node2.input(index)) {
+ return false;
}
}
- if (regular_inputs1 != regular_inputs2) {
- return false;
- }
- std::sort(ctrl_inputs1.begin(), ctrl_inputs1.end());
- std::sort(ctrl_inputs2.begin(), ctrl_inputs2.end());
- if (ctrl_inputs1 != ctrl_inputs2) {
- return false;
+ // The order of control inputs does not matter.
+ if (index < node1.input_size()) {
+ std::vector<string> ctrl_inputs1(node1.input().begin() + index,
+ node1.input().end());
+ std::sort(ctrl_inputs1.begin(), ctrl_inputs1.end());
+ std::vector<string> ctrl_inputs2(node2.input().begin() + index,
+ node2.input().end());
+ std::sort(ctrl_inputs2.begin(), ctrl_inputs2.end());
+ return ctrl_inputs1 != ctrl_inputs2;
}
}
@@ -3330,8 +3336,8 @@
if (node.device().find("SPU") != string::npos) {
return false;
}
- // Workaround for Assert mistakenly being labeled as stateful.
- if (IsAssert(node)) {
+ // Workaround for Assert and Print mistakenly being labeled as stateful.
+ if (IsAssert(node) || IsPrint(node)) {
return true;
}
return IsFreeOfSideEffect(node);
@@ -3369,9 +3375,9 @@
bool stop = true;
std::set<int> duplicates;
+ UniqueNodes nodes;
do {
stop = true;
- UniqueNodes nodes;
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
if (duplicates.find(i) != duplicates.end()) {
continue;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 94c59c6..1220aef 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -3561,6 +3561,75 @@
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasingDoNotChangeMaxPool) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output max_pool = ops::MaxPool(s.WithOpName("max_pool"), neg, {1, 2, 2, 1},
+ {1, 2, 2, 1}, "VALID");
+
+ GrapplerItem item;
+ item.fetch = {"max_pool"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ ASSERT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // Should be a NoOp
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+
+ auto tensors = EvaluateNodes(output, item.fetch);
+ ASSERT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+}
+
+TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWiseMaxPool) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 1.5f, {3, 3, 3, 1});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output max_pool = ops::MaxPool(s.WithOpName("max_pool"), sqrt, {1, 2, 2, 1},
+ {1, 2, 2, 1}, "VALID");
+ Output final_out = ops::Identity(s.WithOpName("final_out"), max_pool);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "sqrt") {
+ EXPECT_EQ("Sqrt", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("max_pool", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "max_pool") {
+ EXPECT_EQ("MaxPool", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index b0c3c5b..f883f89 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -17,6 +17,7 @@
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "absl/strings/string_view.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -716,6 +717,61 @@
return Status::OK();
}
+Status ConstantFolding::MaterializeConstantValuedNode(
+ NodeDef* node, const GraphProperties& properties) {
+ // Nodes that generate constant-valued outputs can be represented compactly in
+ // compressed format, regardless of their shape.
+ const std::vector<OpInfo::TensorProperties>& output_props =
+ properties.GetOutputProperties(node->name());
+ if (output_props.size() != 1) return Status::OK();
+ const auto& output_shape = output_props[0].shape();
+ if (!PartialTensorShape(output_shape).IsFullyDefined()) {
+ return Status::OK();
+ }
+ if (IsFill(*node)) {
+ const auto output_dtype = output_props[0].dtype();
+ NodeDef* input_node = nullptr;
+ for (int i = 0; i < 2; ++i) {
+ input_node = node_map_->GetNode(NodeName(node->input(i)));
+ if (input_node == nullptr || !IsReallyConstant(*input_node)) {
+ return Status::OK();
+ }
+ }
+ TF_RETURN_IF_ERROR(CheckAttrExists(*input_node, "value"));
+ const TensorProto& input_tensor = input_node->attr().at("value").tensor();
+ // TODO(rmlarsen): Handle the case where the value is stored in
+ // tensor_content.
+ if (!input_tensor.tensor_content().empty()) {
+ return Status::OK();
+ }
+ TensorProto* tensor = (*node->mutable_attr())["value"].mutable_tensor();
+ // Copy the input tensor to the fill node, set the output shape, and
+ // change the nodd type to Const.
+ *tensor = input_tensor;
+ *(tensor->mutable_tensor_shape()) = output_shape;
+ (*node->mutable_attr())["dtype"].set_type(output_dtype);
+ node->mutable_attr()->erase("T");
+ node->mutable_attr()->erase("index_type");
+ node->set_op("Const");
+ for (int i = 0; i < 2; i++) {
+ // Change inputs to a control inputs.
+ const string ctrl_dep = AsControlDependency(node->input(i));
+ node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
+ node->set_input(i, ctrl_dep);
+ }
+ graph_modified_ = true;
+ } else {
+ double value =
+ (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
+ bool success = false;
+ if (value >= 0) {
+ TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+ value, properties, output_shape, node, graph_, &success));
+ }
+ }
+ return Status::OK();
+}
+
Status ConstantFolding::MaterializeConstants(
const GraphProperties& properties) {
const int node_count = graph_->node_size();
@@ -726,6 +782,8 @@
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
} else if (IsReduction(node)) {
TF_RETURN_IF_ERROR(MaterializeReductionIndices(&node, properties));
+ } else if (IsFill(node) || IsZerosLike(node) || IsOnesLike(node)) {
+ TF_RETURN_IF_ERROR(MaterializeConstantValuedNode(&node, properties));
}
}
return Status::OK();
@@ -1059,98 +1117,103 @@
return Status::OK();
}
-Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
- bool* result_too_large) {
- if (IsMerge(*node)) {
- // Merge nodes are special, in the sense that they execute as soon as one of
- // their input is ready. We can therefore fold a merge node iff it has at
- // least one constant input without control dependency.
- // We still need to ensure that the nodes in the fanin of the merge node are
- // scheduled. We'll therefore add a control dependency from the merge node
- // to the folded constant. We end up with:
- // * the merge node and its inputs are preserved as is
- // * a new constant node C1, driven by the merge node through a control
- // dependency, initialized to the value of the folded input
- // * a new constant node C2, driven by the merge node through a control
- // dependency, initialized to the index of the folded input
- // * the fanout of the merge nodes is rewired to be driven by either C1 or
- // C2.
- for (int input_index = 0; input_index < node->input_size(); ++input_index) {
- const auto& input = node->input(input_index);
- if (IsControlInput(input)) {
- // Try the next input.
- continue;
+Status ConstantFolding::FoldMergeNode(NodeDef* node, GraphDef* output_graph) {
+ // Merge nodes are special, in the sense that they execute as soon as one of
+ // their input is ready. We can therefore fold a merge node iff it has at
+ // least one constant input without control dependency.
+ // We still need to ensure that the nodes in the fanin of the merge node are
+ // scheduled. We'll therefore add a control dependency from the merge node
+ // to the folded constant. We end up with:
+ // * the merge node and its inputs are preserved as is
+ // * a new constant node C1, driven by the merge node through a control
+ // dependency, initialized to the value of the folded input
+ // * a new constant node C2, driven by the merge node through a control
+ // dependency, initialized to the index of the folded input
+ // * the fanout of the merge nodes is rewired to be driven by either C1 or
+ // C2.
+ for (int input_index = 0; input_index < node->input_size(); ++input_index) {
+ const auto& input = node->input(input_index);
+ if (IsControlInput(input)) {
+ // Try the next input.
+ continue;
+ }
+ NodeDef* input_node = node_map_->GetNode(input);
+ if (!IsReallyConstant(*input_node)) {
+ continue;
+ }
+ bool valid_input = true;
+ for (const string& fanin_of_input : input_node->input()) {
+ if (IsControlInput(fanin_of_input)) {
+ valid_input = false;
+ break;
}
- NodeDef* input_node = node_map_->GetNode(input);
- if (!IsReallyConstant(*input_node)) {
- continue;
- }
- bool valid_input = true;
- for (const string& fanin_of_input : input_node->input()) {
- if (IsControlInput(fanin_of_input)) {
- valid_input = false;
- break;
- }
- }
- if (!valid_input) {
- // Try the next input
- continue;
- }
+ }
+ if (!valid_input) {
+ // Try the next input
+ continue;
+ }
- string const_out_name = OptimizedNodeName(*node, "_const");
- string const_index_name = OptimizedNodeName(*node, "_index");
- if (node_map_->GetNode(const_out_name) ||
- node_map_->GetNode(const_index_name)) {
- // Intended name already exists.
- return errors::AlreadyExists(
- strings::StrCat(const_out_name, " or ", const_index_name,
- " already present in the graph"));
- }
+ string const_out_name = OptimizedNodeName(*node, "_const");
+ string const_index_name = OptimizedNodeName(*node, "_index");
+ if (node_map_->GetNode(const_out_name) ||
+ node_map_->GetNode(const_index_name)) {
+ // Intended name already exists.
+ return errors::AlreadyExists(
+ strings::StrCat(const_out_name, " or ", const_index_name,
+ " already present in the graph"));
+ }
- NodeDef* const_out = output_graph->add_node();
- *const_out = *input_node;
- const_out->set_name(const_out_name);
- const_out->set_device(node->device());
- *const_out->add_input() = AsControlDependency(*node);
- node_map_->AddNode(const_out->name(), const_out);
- node_map_->AddOutput(node->name(), const_out->name());
+ NodeDef* const_out = output_graph->add_node();
+ *const_out = *input_node;
+ const_out->set_name(const_out_name);
+ const_out->set_device(node->device());
+ *const_out->add_input() = AsControlDependency(*node);
+ node_map_->AddNode(const_out->name(), const_out);
+ node_map_->AddOutput(node->name(), const_out->name());
- NodeDef* const_index = output_graph->add_node();
- const_index->set_op("Const");
- Tensor index(DT_INT32, TensorShape({}));
- index.flat<int32>()(0) = input_index;
- (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
- index.AsProtoTensorContent(
- (*const_index->mutable_attr())["value"].mutable_tensor());
- const_index->set_name(const_index_name);
- const_index->set_device(node->device());
- *const_index->add_input() = AsControlDependency(*node);
- node_map_->AddNode(const_index->name(), const_index);
- node_map_->AddOutput(node->name(), const_index->name());
+ NodeDef* const_index = output_graph->add_node();
+ const_index->set_op("Const");
+ Tensor index(DT_INT32, TensorShape({}));
+ index.flat<int32>()(0) = input_index;
+ (*const_index->mutable_attr())["dtype"].set_type(DT_INT32);
+ index.AsProtoTensorContent(
+ (*const_index->mutable_attr())["value"].mutable_tensor());
+ const_index->set_name(const_index_name);
+ const_index->set_device(node->device());
+ *const_index->add_input() = AsControlDependency(*node);
+ node_map_->AddNode(const_index->name(), const_index);
+ node_map_->AddOutput(node->name(), const_index->name());
- auto outputs = node_map_->GetOutputs(node->name());
- for (NodeDef* output : outputs) {
- for (int i = 0; i < output->input_size(); i++) {
- int port;
- string node_name = ParseNodeName(output->input(i), &port);
- if (node_name == node->name()) {
- if (port == 0) {
- *output->mutable_input(i) = const_out->name();
- node_map_->AddOutput(const_out->name(), output->name());
- } else if (port == 1) {
- *output->mutable_input(i) = const_index->name();
- node_map_->AddOutput(const_index->name(), output->name());
- } else {
- // This is a control dependency (or an invalid edge since the
- // merge node has only 2 inputs): preserve them.
- }
+ auto outputs = node_map_->GetOutputs(node->name());
+ for (NodeDef* output : outputs) {
+ for (int i = 0; i < output->input_size(); i++) {
+ int port;
+ string node_name = ParseNodeName(output->input(i), &port);
+ if (node_name == node->name()) {
+ if (port == 0) {
+ *output->mutable_input(i) = const_out->name();
+ node_map_->AddOutput(const_out->name(), output->name());
+ } else if (port == 1) {
+ *output->mutable_input(i) = const_index->name();
+ node_map_->AddOutput(const_index->name(), output->name());
+ } else {
+ // This is a control dependency (or an invalid edge since the
+ // merge node has only 2 inputs): preserve them.
}
}
}
- return Status::OK();
}
return Status::OK();
}
+ return Status::OK();
+}
+
+Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph,
+ bool* result_too_large) {
+ *result_too_large = false;
+ if (IsMerge(*node)) {
+ return FoldMergeNode(node, output_graph);
+ }
std::vector<NodeDef> const_nodes;
TF_RETURN_IF_ERROR(
@@ -1395,7 +1458,8 @@
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
- if (node.op() == "OnesLike") return true;
+ if (IsOnesLike(node)) return true;
+ if (IsZerosLike(node)) return false;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsOnes(*values);
@@ -1428,7 +1492,8 @@
if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
return false;
}
- if (node.op() == "ZerosLike") return true;
+ if (IsOnesLike(node)) return false;
+ if (IsZerosLike(node)) return true;
if (node.op() == "Fill") {
NodeDef* values = node_map_->GetNode(NodeName(node.input(1)));
return values != nullptr && IsZeros(*values);
@@ -1562,6 +1627,7 @@
node->set_input(i, ctrl_dep);
}
*success = true;
+ graph_modified_ = true;
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 9920092..7cf01b4 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -67,8 +67,10 @@
const GraphProperties& properties);
Status MaterializeReductionIndices(NodeDef* node,
const GraphProperties& properties);
-
+ Status MaterializeConstantValuedNode(NodeDef* node,
+ const GraphProperties& properties);
Status MaterializeConstants(const GraphProperties& properties);
+
bool IsFoldable(const NodeDef& node) const;
Status EvaluateNode(const NodeDef& node,
@@ -78,6 +80,7 @@
Status EvaluateOneFoldable(const NodeDef& node, std::vector<NodeDef>* outputs,
bool* result_too_large);
+ Status FoldMergeNode(NodeDef* node, GraphDef* output_graph);
Status FoldNode(NodeDef* node, GraphDef* output_graph,
bool* result_too_large);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index d7cabf5..81d00fa 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -378,7 +378,7 @@
const string ones_name = strings::StrCat("ones", suffix);
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
- EXPECT_EQ(27, output.node_size());
+ EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
@@ -3466,6 +3466,55 @@
}
}
+TEST_F(ConstantFoldingTest, MaterializeConstantValuedNode) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output x =
+ ops::Placeholder(scope.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({1, 2, 3, 4})));
+ Output ones_like = ops::OnesLike(scope.WithOpName("ones_like"), x);
+ Output zeros_like = ops::ZerosLike(scope.WithOpName("zeros_like"), x);
+ Output fill = ops::Fill(scope.WithOpName("fill"), {4, 3, 2, 1}, 42);
+
+ GrapplerItem item;
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+ item.fetch = {"ones_like", "zeros_like", "fill"};
+ auto x_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 2, 3, 4}));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch, {{"x", x_t}});
+
+ ConstantFolding optimizer(/*opt_level=*/RewriterConfig::AGGRESSIVE,
+ /*cpu_device=*/nullptr);
+ GraphDef output;
+ Status status = optimizer.Optimize(/*cluster=*/nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(output.node_size(), 6);
+ for (const auto& node : output.node()) {
+ if (node.name() != "x") {
+ EXPECT_EQ(node.op(), "Const");
+ }
+ if (node.name() == "ones_like" || node.name() == "zeros_like") {
+ ASSERT_EQ(node.input_size(), 1);
+ EXPECT_EQ(node.input(0), "^x");
+ }
+ if (node.name() == "fill") {
+ ASSERT_EQ(node.input_size(), 2);
+ EXPECT_EQ(node.input(0)[0], '^');
+ EXPECT_EQ(node.input(1)[0], '^');
+ }
+ }
+ auto tensors = EvaluateNodes(output, item.fetch, {{"x", x_t}});
+ ASSERT_EQ(item.fetch.size(), tensors.size());
+ ASSERT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); i++) {
+ if (item.fetch[i] == "fill") {
+ test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
+ } else {
+ test::ExpectTensorEqual<float>(tensors_expected[i], tensors[i]);
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 682e7cd..ef02962 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -6,6 +6,7 @@
package(default_visibility = [
"//tensorflow/core/grappler/optimizers/data:__subpackages__",
"//tensorflow/core/kernels/data:__pkg__",
+ "//tensorflow/core/kernels/data/experimental:__pkg__",
])
cc_library(
@@ -541,6 +542,20 @@
)
cc_library(
+ name = "rebatch",
+ srcs = ["rebatch.cc"],
+ hdrs = ["rebatch.h"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ],
+ alwayslink = 1,
+)
+
+cc_library(
name = "noop_elimination",
srcs = ["noop_elimination.cc"],
hdrs = [
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index cbafb9d..7bcc12c 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -232,6 +232,13 @@
return graph.GetRegularFanin(input_port).node;
}
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
+ int64 i) {
+ if (node.input_size() <= i) return nullptr;
+ MutableGraphView::InputPort input_port = graph.GetInputPort(node.name(), i);
+ return graph.GetRegularFanin(input_port).node;
+}
+
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
string name = string(prefix);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 8f2872c..22298cc 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -108,6 +108,10 @@
// Gets the 0th input to a node in the graph.
NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
+// Gets the ith input to a node in the graph.
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph,
+ int64 i);
+
// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.
std::vector<int> FindAllGraphNodesWithOp(const string& op,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 3b6d223..879cecd 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -228,6 +228,21 @@
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
+TEST(GraphUtilsTest, GetIthInputNode) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+
+ NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
+ NodeDef* node2 = AddNode("", "A", {}, {}, &graph);
+ NodeDef* node3 = AddNode("", "A", {node1->name(), node2->name()}, {}, &graph);
+
+ EXPECT_EQ(GetInputNode(*node3, graph), node1);
+ EXPECT_EQ(GetInputNode(*node3, graph, 1), node2);
+ EXPECT_EQ(GetInputNode(*node3, graph, 0), node1);
+ EXPECT_EQ(GetInputNode(*node3, graph, 2), nullptr);
+ EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
+}
+
TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
Graph g(OpRegistry::Global());
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
index 4529b89..9bff068 100644
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
@@ -37,8 +37,8 @@
NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kInsertOpName);
- graph_utils::SetUniqueGraphNodeName(
- strings::StrCat(kInsertOpName, "_generated"), graph->graph(), &new_node);
+ graph_utils::SetUniqueGraphNodeName(strings::StrCat(kInsertOpName),
+ graph->graph(), &new_node);
// Set the input of LatencyDataset node as `node`
new_node.add_input(node.name());
@@ -75,8 +75,7 @@
// TODO(shivaniagrawal): Add Op to return Latency for the particular Op than
// for the edge (e2 - e1?).
for (const NodeDef& node : item.graph.node()) {
- if (!str_util::EndsWith(node.op(), "Dataset") || node.attr().empty() ||
- str_util::EndsWith(node.name(), "_generated")) {
+ if (!str_util::EndsWith(node.op(), "Dataset") || node.attr().empty()) {
// TODO(b/111805951): Replace this with non-approximate way to check if
// node corresponds to a `Dataset` op.
continue;
@@ -87,15 +86,8 @@
if (fanout.size() > 1) {
LOG(WARNING) << node.name() << " has fanout size " << fanout.size();
continue;
- } else { // fanout will have size 0 for last dataset node in the pipeline.
- if (fanout.size() == 1) {
- NodeDef* output_node = (*(fanout.begin())).node;
- if (str_util::EndsWith(output_node->name(), "_generated")) {
- continue;
- }
- }
}
-
+ // fanout will have size 0 for last dataset node in the pipeline.
NodeDef* latency_node = graph.AddNode(MakeLatencyNode(node, &graph));
TF_RETURN_IF_ERROR(graph.UpdateFanouts(node.name(), latency_node->name()));
stats->num_changes++;
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 13933b6..5c8f780 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -224,9 +224,18 @@
auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
const auto& input_shapes =
input_node.attr().at("output_shapes").list().shape();
- int64 batch_size =
- old_batch_node.attr().at("output_shapes").list().shape()[0].dim(0).size();
+
+ int64 batch_size = -1;
+ for (const auto& shape :
+ old_batch_node.attr().at("output_shapes").list().shape()) {
+ if (!shape.unknown_rank()) {
+ batch_size = shape.dim(0).size();
+ break;
+ }
+ }
+
for (size_t i = 0; i < input_shapes.size(); ++i) {
+ // Note: We already checked earlier that input shapes are all fully defined.
TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
TensorShapeProto_Dim* dim = shape->add_dim();
dim->set_size(batch_size);
diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.cc b/tensorflow/core/grappler/optimizers/data/rebatch.cc
new file mode 100644
index 0000000..187e1a6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/rebatch.cc
@@ -0,0 +1,115 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/rebatch.h"
+
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status RebatchOptimizer::Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) {
+ if (!config) return Status::OK();
+
+ num_workers_ = config->parameter_map().at("num_workers").i();
+ return Status::OK();
+}
+
+namespace {
+
+constexpr char kCastOp[] = "Cast";
+constexpr char kRealDivOp[] = "RealDiv";
+constexpr char kBatchDatasetOp[] = "BatchDatasetV2";
+
+NodeDef* AddCastNode(const string& input, DataType src_t, DataType dst_t,
+ MutableGraphView* graph) {
+ NodeDef cast_node;
+ cast_node.set_op(kCastOp);
+ cast_node.add_input(input);
+ graph_utils::SetUniqueGraphNodeName(cast_node.op(), graph->graph(),
+ &cast_node);
+ AddNodeAttr("SrcT", src_t, &cast_node);
+ AddNodeAttr("DstT", dst_t, &cast_node);
+
+ return graph->AddNode(std::move(cast_node));
+}
+
+NodeDef* AddBinaryNode(const string& input_x, const string& input_y,
+ const string& op, DataType type,
+ MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op(op);
+ node.add_input(input_x);
+ node.add_input(input_y);
+ graph_utils::SetUniqueGraphNodeName(op, graph->graph(), &node);
+ AddNodeAttr("T", type, &node);
+
+ return graph->AddNode(std::move(node));
+}
+
+NodeDef* AddFloatDivNode(const string& input_x, const string& input_y,
+ MutableGraphView* graph) {
+ return AddBinaryNode(input_x, input_y, kRealDivOp, DT_FLOAT, graph);
+}
+
+} // anonymous namespace
+
+Status RebatchOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+
+ absl::flat_hash_set<string> nodes_to_delete;
+ for (const NodeDef& node : item.graph.node()) {
+ if (node.op() == kBatchDatasetOp) {
+ NodeDef* batch_size_node = graph_utils::GetInputNode(node, graph, 1);
+ NodeDef tmp_node;
+ tmp_node = *batch_size_node;
+ graph_utils::SetUniqueGraphNodeName(tmp_node.op(), graph.graph(),
+ &tmp_node);
+ NodeDef* copy_batch_size_node = graph.AddNode(std::move(tmp_node));
+ NodeDef* float_copy_batch_size_node =
+ AddCastNode(copy_batch_size_node->name(), DT_INT64, DT_FLOAT, &graph);
+ NodeDef* num_worker_node =
+ graph_utils::AddScalarConstNode<int64>(num_workers_, &graph);
+ NodeDef* float_num_worker_node =
+ AddCastNode(num_worker_node->name(), DT_INT64, DT_FLOAT, &graph);
+ NodeDef* divided_batch_size_node =
+ AddFloatDivNode(float_copy_batch_size_node->name(),
+ float_num_worker_node->name(), &graph);
+ NodeDef* cast_new_batch_size_node = AddCastNode(
+ divided_batch_size_node->name(), DT_FLOAT, DT_INT64, &graph);
+ TF_RETURN_IF_ERROR(graph.UpdateFanouts(batch_size_node->name(),
+ cast_new_batch_size_node->name()));
+ nodes_to_delete.insert(batch_size_node->name());
+ break;
+ }
+ }
+ TF_RETURN_IF_ERROR(graph.DeleteNodes(nodes_to_delete));
+ return Status::OK();
+}
+
+void RebatchOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {}
+
+REGISTER_GRAPH_OPTIMIZER_AS(RebatchOptimizer, "tf_data_rebatcher");
+
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/rebatch.h b/tensorflow/core/grappler/optimizers/data/rebatch.h
new file mode 100644
index 0000000..f7aa69f
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/rebatch.h
@@ -0,0 +1,52 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimizer changes the batch size of the output dataset by dividing the
+// current batch size by parameter `num_workers`. Currently, this works only
+// for very simple pipelines with a single BatchDatasetV2 transformation.
+//
+// TODO(rohanj): Extend this logic to correctly handle any input pipeline that
+// uses core tf.data APIs + MapAndBatch.
+class RebatchOptimizer : public CustomGraphOptimizer {
+ public:
+ RebatchOptimizer() = default;
+ ~RebatchOptimizer() override = default;
+
+ string name() const override { return "tf_data_rebatcher"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override;
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+
+ private:
+ int64 num_workers_;
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_REBATCH_H_
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index ff64ff1..0563460 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -16,7 +16,6 @@
#include "tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.h"
#include "absl/container/flat_hash_set.h"
-#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/grappler_item.h"
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 0eee91f..0f34d2b 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -14,7 +14,6 @@
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
-#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c309caa..33942a5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -124,7 +124,8 @@
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -164,7 +165,7 @@
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
- if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -552,7 +553,8 @@
// Optimize each function only once.
absl::flat_hash_set<string> optimized_funcs;
- bool optimize_function_library = true;
+ bool optimize_function_library =
+ item.optimization_options().optimize_function_library;
while (optimize_function_library) {
optimize_function_library = false;
@@ -604,7 +606,8 @@
// instantiated by the function definition, because we must guarantee
// function execution semantics wrt side effects (see
// function_optimizer.cc).
- func_item.optimization_options().is_function_instantiation = true;
+ func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
+ false;
// Optimize function body graph.
GraphDef optimized_func_graph;
@@ -679,7 +682,7 @@
rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
- rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
+ rewrite_cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
!rewrite_cfg.optimizers().empty() ||
!rewrite_cfg.custom_optimizers().empty();
}
@@ -700,7 +703,7 @@
Status OptimizeGraph(
std::vector<string> ret_node_names, FunctionLibraryDefinition* flib,
const DeviceSet& device_set, Device* cpu_device,
- const ConfigProto& config_proto,
+ const ConfigProto& config_proto, const string& grappler_item_id,
const GrapplerItem::OptimizationOptions& optimization_options,
std::unique_ptr<tensorflow::Graph>* g) {
if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
@@ -708,6 +711,7 @@
}
tensorflow::grappler::GrapplerItem item;
+ item.id = grappler_item_id;
item.optimization_options() = optimization_options;
// Add all available devices so that inlined function can be placed.
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index ec78cc5..751a9e5 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -120,6 +120,7 @@
// `device_set`: the set of devices that graph can refer to.
// `cpu_device`: the CPU device.
// `config_proto`: Grapper configuration.
+// `grappler_item_id': Grappler item id (e.g. optimized function name).
// `optimization_options`: Grappler optimization constraints that are known only
// at runtime.
//
@@ -130,7 +131,7 @@
Status OptimizeGraph(
std::vector<string> ret_node_names, FunctionLibraryDefinition* lib,
const DeviceSet& device_set, Device* cpu_device,
- const ConfigProto& config_proto,
+ const ConfigProto& config_proto, const string& grappler_item_id,
const GrapplerItem::OptimizationOptions& optimization_options,
std::unique_ptr<tensorflow::Graph>* g);
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 357a0b3..35979d4 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -314,7 +314,7 @@
std::vector<string> keep_nodes, const int graph_def_version,
const bool is_stateful, GraphDef&& function_body)
: description_(std::move(description)),
- func_attr_(std::move(func_attr)),
+ func_attr_(func_attr),
input_arg_expansions_(std::move(input_arg_expansions)),
output_arg_expansions_(std::move(output_arg_expansions)),
is_stateful_(is_stateful) {
@@ -339,7 +339,7 @@
// Tensorflow functions execution semantics is different from the main graph,
// and we need to preserve it when we do graph optimizations.
- optimization_options().is_function_instantiation = true;
+ optimization_options().allow_pruning_stateful_and_dataset_ops = false;
}
const string& GrapplerFunctionItem::description() const { return description_; }
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index 7720888..30b6195 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -641,7 +641,9 @@
EXPECT_EQ(3, item.function_body().node_size());
EXPECT_EQ(1, item.input_size());
EXPECT_EQ(0, item.output_size());
- EXPECT_EQ(true, item.optimization_options().is_function_instantiation);
+
+ const auto &opts = item.optimization_options();
+ EXPECT_FALSE(opts.allow_pruning_stateful_and_dataset_ops);
}
TEST_F(FunctionsTest, MakeFunctionDef) {
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index ffb093b..ea651f1 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -5615,6 +5615,7 @@
"gemm_functors.h",
"image_resizer_state.h",
"initializable_lookup_table.h",
+ "logging_ops.h",
"lookup_table_init_op.h",
"lookup_table_op.h",
"lookup_util.h",
@@ -6560,6 +6561,30 @@
],
)
+tf_cc_test_mkl(
+ name = "mkl_quantized_concat_op_test",
+ size = "small",
+ srcs = ["mkl_quantized_concat_op_test.cc"],
+ deps = [
+ ":mkl_concat_op",
+ ":ops_testutil",
+ ":ops_util",
+ ":quantization_utils",
+ ":quantized_ops",
+ "//tensorflow/core:array_ops_op_lib",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:math_ops_op_lib",
+ "//tensorflow/core:mkl_array_ops_op_lib",
+ "//tensorflow/core:nn_ops_op_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_cc_test(
name = "quantized_batch_norm_op_test",
size = "small",
@@ -7064,6 +7089,7 @@
"//tensorflow/core/util/proto:descriptors",
"//tensorflow/core/util/proto:proto_utils",
"//third_party/eigen3",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -7145,3 +7171,12 @@
":cwise_lib",
],
)
+
+# Library to link with when compiling the quantize and dequantize kernels directly,
+# e.g. for selective registration.
+cc_header_only_library(
+ name = "quantize_and_dequantize_op_hdrs",
+ deps = [
+ ":quantize_and_dequantize_op",
+ ],
+)
diff --git a/tensorflow/core/kernels/collective_nccl_reducer.cc b/tensorflow/core/kernels/collective_nccl_reducer.cc
index 113f148..c5e6f06 100644
--- a/tensorflow/core/kernels/collective_nccl_reducer.cc
+++ b/tensorflow/core/kernels/collective_nccl_reducer.cc
@@ -149,7 +149,9 @@
<< col_params_->group.num_tasks << " current task "
<< col_params_->instance.task_names[col_params_->default_rank]
<< " num local devices " << num_local_devices
- << " num global devices " << num_global_devices;
+ << " num global devices " << num_global_devices << " device "
+ << col_ctx_->device_name << " instance "
+ << col_params_->instance.instance_key;
NcclManager::instance()->AddToAllReduce(
std::move(participant),
{nccl_collective_key, num_local_devices, num_global_devices,
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 04e37e8..56843eb 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -43,6 +43,10 @@
// Call in a blockable thread because it's not guaranteed that
// this call cannot block.
c->env()->SchedClosure([this, c, done, col_exec]() {
+ VLOG(1) << "CollectiveOpKernel CompleteParams for collective "
+ << col_params_.name << " device " << c->device()->name()
+ << " group " << col_params_.group.group_key << " instance "
+ << col_params_.instance.instance_key;
col_exec->CompleteParamsAsync(
c->device()->name(), &col_params_, c->cancellation_manager(),
[this, c, done](const Status& s) {
@@ -149,10 +153,18 @@
col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
- auto actual_done = [c, col_exec, done](const Status& s) {
+
+ int32 instance_key = col_params_.instance.instance_key;
+ auto actual_done = [c, instance_key, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
+ VLOG(1) << "CollectiveReduceKernel ExecuteAsync done for device "
+ << c->device()->name() << " instance " << instance_key;
};
+ VLOG(1) << "CollectiveReduceKernel ExecuteAsync start for collective "
+ << col_params_.name << " device " << c->device()->name()
+ << " group " << col_params_.group.group_key << " instance "
+ << instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@@ -211,10 +223,17 @@
" does not match shape of input"),
done);
- auto actual_done = [c, col_exec, done](const Status& s) {
+ int32 instance_key = col_params_.instance.instance_key;
+ auto actual_done = [c, instance_key, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
+ VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync done for device "
+ << c->device()->name() << " instance " << instance_key;
};
+ VLOG(1) << "CollectiveBcastSendOpKernel ExecuteAsync start for collective "
+ << col_params_.name << " device " << c->device()->name()
+ << " group " << col_params_.group.group_key << " instance "
+ << instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
@@ -266,10 +285,17 @@
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
- auto actual_done = [c, col_exec, done](const Status& s) {
+ int32 instance_key = col_params_.instance.instance_key;
+ auto actual_done = [c, instance_key, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
+ VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync done for device "
+ << c->device()->name() << " instance " << instance_key;
};
+ VLOG(1) << "CollectiveBcastRecvOpKernel ExecuteAsync start for collective "
+ << col_params_.name << " device " << c->device()->name()
+ << " group " << col_params_.group.group_key << " instance "
+ << instance_key;
col_exec->ExecuteAsync(c, col_params_, GetCollectiveKey(c), actual_done);
}
diff --git a/tensorflow/core/kernels/concat_lib_gpu.cc b/tensorflow/core/kernels/concat_lib_gpu.cc
index 161810d..853d7c3 100644
--- a/tensorflow/core/kernels/concat_lib_gpu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu.cc
@@ -117,6 +117,7 @@
TF_CALL_complex128(REGISTER);
TF_CALL_int32(REGISTER); // Needed for TensorLists.
TF_CALL_int64(REGISTER);
+TF_CALL_int16(REGISTER);
TF_CALL_bfloat16(REGISTER);
TF_CALL_bool(REGISTER);
TF_CALL_uint8(REGISTER);
diff --git a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
index 1a9adfa..ae828b5 100644
--- a/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
+++ b/tensorflow/core/kernels/concat_lib_gpu_impl.cu.cc
@@ -203,6 +203,7 @@
TF_CALL_complex128(REGISTER_GPUCONCAT32);
TF_CALL_int32(REGISTER_GPUCONCAT32); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT32);
+TF_CALL_int16(REGISTER_GPUCONCAT32);
TF_CALL_uint8(REGISTER_GPUCONCAT32);
REGISTER_GPUCONCAT32(bfloat16);
REGISTER_GPUCONCAT32(bool);
@@ -212,6 +213,7 @@
TF_CALL_complex128(REGISTER_GPUCONCAT64);
TF_CALL_int32(REGISTER_GPUCONCAT64); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPUCONCAT64);
+TF_CALL_int16(REGISTER_GPUCONCAT64);
TF_CALL_uint8(REGISTER_GPUCONCAT64);
REGISTER_GPUCONCAT64(bfloat16);
REGISTER_GPUCONCAT64(bool);
@@ -221,6 +223,7 @@
TF_CALL_complex128(REGISTER_GPU32);
TF_CALL_int32(REGISTER_GPU32); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU32);
+TF_CALL_int16(REGISTER_GPU32);
TF_CALL_uint8(REGISTER_GPU32);
REGISTER_GPU32(bfloat16);
REGISTER_GPU32(bool);
@@ -230,6 +233,7 @@
TF_CALL_complex128(REGISTER_GPU64);
TF_CALL_int32(REGISTER_GPU64); // Needed for TensorLists.
TF_CALL_int64(REGISTER_GPU64);
+TF_CALL_int16(REGISTER_GPU64);
TF_CALL_uint8(REGISTER_GPU64);
REGISTER_GPU64(bfloat16);
REGISTER_GPU64(bool);
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 0fd7550..9ceb510 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -78,9 +78,6 @@
" stride: ", dim->stride, " dilation: ", dim->dilation);
}
- // TODO(reedwm): Correctly handle explicit padding here. The rest of the
- // fields set on 'dim' are only used in XLA. TensorFlow ops do not yet support
- // explicit padding for XLA.
int64 effective_filter_size = (dim->filter_size - 1) * dim->dilation + 1;
dim->expanded_output_size = (dim->output_size - 1) * dim->stride + 1;
const auto padded_out_size = dim->input_size + effective_filter_size - 1;
@@ -102,7 +99,7 @@
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
- Padding padding, const std::vector<int64>& explicit_paddings,
+ Padding padding, absl::Span<const int64> explicit_paddings,
TensorFormat data_format, ConvBackpropDimensions* dims) {
// The + 2 in the following line is for the batch and feature dimensions.
const int num_dims = num_spatial_dims + 2;
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index c8e8cf2..173f928 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -222,7 +222,7 @@
int64 stride;
int64 dilation;
- // The following fields are valid only if the padding is not EXPLICIT.
+ // Output size after scaling by the stride.
int64 expanded_output_size;
// Number of padding elements to be added before/after this dimension of
@@ -270,7 +270,7 @@
StringPiece label, int num_spatial_dims, const TensorShape& input_shape,
const TensorShape& filter_shape, const TensorShape& out_backprop_shape,
const gtl::ArraySlice<int32>& dilations, const std::vector<int32>& strides,
- Padding padding, const std::vector<int64>& explicit_paddings,
+ Padding padding, absl::Span<const int64> explicit_paddings,
TensorFormat data_format, ConvBackpropDimensions* dims);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 3cadb55..56c8399 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -9,8 +9,8 @@
load(
"//tensorflow:tensorflow.bzl",
- "tf_kernel_library",
"tf_cc_test",
+ "tf_kernel_library",
)
# TODO(mrry): Remove this empty forwarding library.
@@ -22,6 +22,27 @@
)
cc_library(
+ name = "dataset_test_base",
+ testonly = 1,
+ srcs = ["dataset_test_base.cc"],
+ hdrs = ["dataset_test_base.h"],
+ deps = [
+ ":dataset_utils",
+ ":iterator_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensor_testutil",
+ "//tensorflow/core:test",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ ],
+)
+
+cc_library(
name = "dataset_utils",
srcs = ["dataset_utils.cc"],
hdrs = ["dataset_utils.h"],
@@ -123,6 +144,17 @@
)
tf_kernel_library(
+ name = "shard_dataset_op",
+ srcs = ["shard_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "window_dataset_op",
srcs = ["window_dataset_op.cc"],
deps = [
@@ -184,6 +216,28 @@
],
)
+tf_cc_test(
+ name = "map_dataset_op_test",
+ size = "small",
+ srcs = ["map_dataset_op_test.cc"],
+ deps = [
+ ":dataset_test_base",
+ ":dataset_utils",
+ ":iterator_ops",
+ ":map_dataset_op",
+ ":range_dataset_op",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:cwise_op",
+ "//tensorflow/core/kernels:function_ops",
+ ],
+)
+
cc_library(
name = "parallel_map_iterator",
srcs = ["parallel_map_iterator.cc"],
@@ -350,6 +404,23 @@
],
)
+tf_cc_test(
+ name = "range_dataset_op_test",
+ size = "small",
+ srcs = ["range_dataset_op_test.cc"],
+ deps = [
+ ":dataset_test_base",
+ ":dataset_utils",
+ ":iterator_ops",
+ ":range_dataset_op",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:ptr_util",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
tf_kernel_library(
name = "shuffle_dataset_op",
srcs = ["shuffle_dataset_op.cc"],
@@ -488,17 +559,14 @@
],
)
-tf_kernel_library(
- name = "optimize_dataset_op",
- srcs = ["optimize_dataset_op.cc"],
+cc_library(
+ name = "graph_rewrite_dataset",
+ srcs = ["graph_rewrite_dataset.cc"],
+ hdrs = ["graph_rewrite_dataset.h"],
deps = [
"//tensorflow/core:core_cpu_internal",
- "//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:grappler_item_builder",
"//tensorflow/core/grappler/clusters:virtual_cluster",
@@ -510,6 +578,19 @@
)
tf_kernel_library(
+ name = "optimize_dataset_op",
+ srcs = ["optimize_dataset_op.cc"],
+ deps = [
+ ":graph_rewrite_dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_kernel_library(
name = "model_dataset_op",
srcs = ["model_dataset_op.cc"],
deps = [
@@ -563,6 +644,7 @@
":range_dataset_op",
":reader_dataset_ops",
":repeat_dataset_op",
+ ":shard_dataset_op",
":shuffle_dataset_op",
":skip_dataset_op",
":sparse_tensor_slice_dataset_op",
diff --git a/tensorflow/core/kernels/data/dataset_test_base.cc b/tensorflow/core/kernels/data/dataset_test_base.cc
new file mode 100644
index 0000000..36a862f
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_test_base.cc
@@ -0,0 +1,196 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+
+namespace tensorflow {
+namespace data {
+
+Status DatasetOpsTestBase::CreateOpKernel(
+ const NodeDef& node_def, std::unique_ptr<OpKernel>* op_kernel) {
+ Status status;
+ *op_kernel =
+ tensorflow::CreateOpKernel(device_type_, device_.get(), allocator_,
+ node_def, TF_GRAPH_DEF_VERSION, &status);
+ return status;
+}
+
+Status DatasetOpsTestBase::CreateDataset(OpKernel* kernel,
+ OpKernelContext* context,
+ DatasetBase** const dataset) {
+ TF_RETURN_IF_ERROR(RunOpKernel(kernel, context));
+ // Assume that DatasetOp has only one output.
+ DCHECK_EQ(context->num_outputs(), 1);
+ TF_RETURN_IF_ERROR(GetDatasetFromContext(context, 0, dataset));
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::CreateIteratorContext(
+ OpKernelContext* const op_context,
+ std::unique_ptr<IteratorContext>* iterator_context) {
+ IteratorContext::Params params(op_context);
+ function_handle_cache_ = absl::make_unique<FunctionHandleCache>(flr_);
+ params.function_handle_cache = function_handle_cache_.get();
+ *iterator_context = absl::make_unique<IteratorContext>(params);
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::GetDatasetFromContext(OpKernelContext* context,
+ int output_index,
+ DatasetBase** const dataset) {
+ Tensor* output = context->mutable_output(output_index);
+ Status status = GetDatasetFromVariantTensor(*output, dataset);
+ (*dataset)->Ref();
+ return status;
+}
+
+Status DatasetOpsTestBase::InitThreadPool(int thread_num) {
+ if (thread_num < 1) {
+ return errors::InvalidArgument(
+ "The `thread_num` argument should be positive but got: ", thread_num);
+ }
+ thread_pool_ = absl::make_unique<thread::ThreadPool>(
+ Env::Default(), ThreadOptions(), "inter_op", thread_num);
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::InitFunctionLibraryRuntime(
+ const std::vector<FunctionDef>& flib, int cpu_num) {
+ if (cpu_num < 1) {
+ return errors::InvalidArgument(
+ "The `cpu_num` argument should be positive but got: ", cpu_num);
+ }
+ SessionOptions options;
+ auto* device_count = options.config.mutable_device_count();
+ device_count->insert({"CPU", cpu_num});
+ std::vector<std::unique_ptr<Device>> devices;
+ TF_RETURN_IF_ERROR(DeviceFactory::AddDevices(
+ options, "/job:localhost/replica:0/task:0", &devices));
+ device_mgr_ = absl::make_unique<DeviceMgr>(std::move(devices));
+
+ FunctionDefLibrary proto;
+ for (const auto& fdef : flib) *(proto.add_function()) = fdef;
+ lib_def_ =
+ absl::make_unique<FunctionLibraryDefinition>(OpRegistry::Global(), proto);
+
+ OptimizerOptions opts;
+ pflr_ = absl::make_unique<ProcessFunctionLibraryRuntime>(
+ device_mgr_.get(), Env::Default(), TF_GRAPH_DEF_VERSION, lib_def_.get(),
+ opts, thread_pool_.get(), nullptr /* cluster_flr */);
+ flr_ = pflr_->GetFLR("/job:localhost/replica:0/task:0/cpu:0");
+ if (thread_pool_ == nullptr) {
+ runner_ = [](std::function<void()> fn) { fn(); };
+ } else {
+ runner_ = [this](std::function<void()> fn) {
+ thread_pool_->Schedule(std::move(fn));
+ };
+ }
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::RunOpKernel(OpKernel* op_kernel,
+ OpKernelContext* context) {
+ device_->Compute(op_kernel, context);
+ return context->status();
+}
+
+Status DatasetOpsTestBase::CreateOpKernelContext(
+ OpKernel* kernel, gtl::InlinedVector<TensorValue, 4>* inputs,
+ std::unique_ptr<OpKernelContext>* context) {
+ params_ = absl::make_unique<OpKernelContext::Params>();
+ params_->device = device_.get();
+ params_->resource_manager = device_->resource_manager();
+ params_->frame_iter = FrameAndIter(0, 0);
+ params_->inputs = inputs;
+ params_->op_kernel = kernel;
+ params_->function_library = flr_;
+ params_->runner = &runner_;
+ step_container_ =
+ absl::make_unique<ScopedStepContainer>(0, [](const string&) {});
+ params_->step_container = step_container_.get();
+ checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+ slice_reader_cache_ =
+ absl::make_unique<checkpoint::TensorSliceReaderCacheWrapper>();
+ params_->slice_reader_cache = slice_reader_cache_.get();
+
+ // Set the allocator attributes for the outputs.
+ allocator_attrs_.clear();
+ for (int index = 0; index < params_->op_kernel->num_outputs(); index++) {
+ AllocatorAttributes attr;
+ const bool on_host =
+ (params_->op_kernel->output_memory_types()[index] == HOST_MEMORY);
+ attr.set_on_host(on_host);
+ allocator_attrs_.emplace_back(attr);
+ }
+ params_->output_attr_array = gtl::vector_as_array(&allocator_attrs_);
+
+ *context = absl::make_unique<OpKernelContext>(params_.get());
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::CreateSerializationContext(
+ std::unique_ptr<SerializationContext>* context) {
+ SerializationContext::Params params;
+ params.flib_def = lib_def_.get();
+ *context = absl::make_unique<SerializationContext>(params);
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::CheckOpKernelInput(
+ const OpKernel& kernel, const gtl::InlinedVector<TensorValue, 4>& inputs) {
+ if (kernel.input_types().size() != inputs.size()) {
+ return errors::Internal("The number of input elements should be ",
+ kernel.input_types().size(),
+ ", but got: ", inputs.size());
+ }
+ return Status::OK();
+}
+
+Status DatasetOpsTestBase::AddDatasetInput(
+ gtl::InlinedVector<TensorValue, 4>* inputs, DataTypeVector input_types,
+ DataType dtype, const TensorShape& shape) {
+ if (input_types.size() < inputs->size()) {
+ return errors::InvalidArgument("Adding more inputs than types: ",
+ inputs->size(), " vs. ", input_types.size());
+ }
+ bool is_ref = IsRefType(input_types[inputs->size()]);
+ std::unique_ptr<Tensor> input =
+ absl::make_unique<Tensor>(allocator_, dtype, shape);
+
+ if (is_ref) {
+ DataType expected_dtype = RemoveRefType(input_types[inputs->size()]);
+ if (expected_dtype != dtype) {
+ return errors::InvalidArgument("The input data type is ", dtype,
+ " , but expected: ", expected_dtype);
+ }
+ inputs->push_back({&lock_for_refs_, input.get()});
+ } else {
+ if (input_types[inputs->size()] != dtype) {
+ return errors::InvalidArgument(
+ "The input data type is ", dtype,
+ " , but expected: ", input_types[inputs->size()]);
+ }
+ inputs->push_back({nullptr, input.get()});
+ }
+
+ // TODO(jsimsa): Figure out how to avoid using a member variable to garbage
+ // collect the inputs.
+ tensors_.push_back(std::move(input));
+
+ return Status::OK();
+}
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_test_base.h b/tensorflow/core/kernels/data/dataset_test_base.h
new file mode 100644
index 0000000..ee50083
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_test_base.h
@@ -0,0 +1,180 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
+
+#include <vector>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_handle_cache.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/iterator_ops.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace data {
+
+// Helpful functions to test Dataset op kernels.
+class DatasetOpsTestBase : public ::testing::Test {
+ public:
+ DatasetOpsTestBase()
+ : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
+ device_type_(DEVICE_CPU) {
+ allocator_ = device_->GetAllocator(AllocatorAttributes());
+ }
+
+ ~DatasetOpsTestBase() {}
+
+ // Creates a new op kernel based on the node definition.
+ Status CreateOpKernel(const NodeDef& node_def,
+ std::unique_ptr<OpKernel>* op_kernel);
+
+ // Creates a new dataset.
+ Status CreateDataset(OpKernel* kernel, OpKernelContext* context,
+ DatasetBase** const dataset);
+
+ // Creates a new RangeDataset op kernel. `T` specifies the output dtype of the
+ // op kernel.
+ template <typename T>
+ Status CreateRangeDatasetOpKernel(
+ StringPiece node_name, std::unique_ptr<OpKernel>* range_op_kernel) {
+ DataTypeVector dtypes({tensorflow::DataTypeToEnum<T>::value});
+ std::vector<PartialTensorShape> shapes({{}});
+ NodeDef node_def = test::function::NDef(
+ node_name, "RangeDataset", {"start", "stop", "step"},
+ {{"output_types", dtypes}, {"output_shapes", shapes}});
+
+ TF_RETURN_IF_ERROR(CreateOpKernel(node_def, range_op_kernel));
+ return Status::OK();
+ }
+
+ // Creates a new RangeDataset dataset. `T` specifies the output dtype of the
+ // RangeDataset op kernel.
+ template <typename T>
+ Status CreateRangeDataset(int64 start, int64 end, int64 step,
+ StringPiece node_name,
+ DatasetBase** range_dataset) {
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_RETURN_IF_ERROR(CreateRangeDatasetOpKernel<T>(node_name, &range_kernel));
+ gtl::InlinedVector<TensorValue, 4> range_inputs;
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &range_inputs, range_kernel->input_types(), TensorShape({}), {start}));
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &range_inputs, range_kernel->input_types(), TensorShape({}), {end}));
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &range_inputs, range_kernel->input_types(), TensorShape({}), {step}));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_RETURN_IF_ERROR(CreateOpKernelContext(range_kernel.get(), &range_inputs,
+ &range_context));
+ TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, range_inputs));
+ TF_RETURN_IF_ERROR(RunOpKernel(range_kernel.get(), range_context.get()));
+ TF_RETURN_IF_ERROR(
+ GetDatasetFromContext(range_context.get(), 0, range_dataset));
+ return Status::OK();
+ }
+
+ // Fetches the dataset from the operation context.
+ Status GetDatasetFromContext(OpKernelContext* context, int output_index,
+ DatasetBase** const dataset);
+
+ protected:
+ // Creates a thread pool for parallel tasks.
+ Status InitThreadPool(int thread_num);
+
+ // Initializes the runtime for computing the dataset operation and registers
+ // the input function definitions. `InitThreadPool()' needs to be called
+ // before this method if we want to run the tasks in parallel.
+ Status InitFunctionLibraryRuntime(const std::vector<FunctionDef>& flib,
+ int cpu_num);
+
+ // Runs an operation producing outputs.
+ Status RunOpKernel(OpKernel* op_kernel, OpKernelContext* context);
+
+ // Checks that the size of `inputs` matches the requirement of the op kernel.
+ Status CheckOpKernelInput(const OpKernel& kernel,
+ const gtl::InlinedVector<TensorValue, 4>& inputs);
+
+ // Creates a new context for running the dataset operation.
+ Status CreateOpKernelContext(OpKernel* kernel,
+ gtl::InlinedVector<TensorValue, 4>* inputs,
+ std::unique_ptr<OpKernelContext>* context);
+
+ // Creates a new iterator context for iterating the dataset.
+ Status CreateIteratorContext(
+ OpKernelContext* const op_context,
+ std::unique_ptr<IteratorContext>* iterator_context);
+
+ // Creates a new serialization context for serializing the dataset and
+ // iterator.
+ Status CreateSerializationContext(
+ std::unique_ptr<SerializationContext>* context);
+
+ // Adds an arrayslice of data into the input vector. `input_types` describes
+ // the required data type for each input tensor. `shape` and `data` describes
+ // the shape and values of the current input tensor. `T` specifies the dtype
+ // of the input data.
+ template <typename T>
+ Status AddDatasetInputFromArray(gtl::InlinedVector<TensorValue, 4>* inputs,
+ DataTypeVector input_types,
+ const TensorShape& shape,
+ const gtl::ArraySlice<T>& data) {
+ TF_RETURN_IF_ERROR(
+ AddDatasetInput(inputs, input_types, DataTypeToEnum<T>::v(), shape));
+ test::FillValues<T>(inputs->back().tensor, data);
+ return Status::OK();
+ }
+
+ private:
+ // Adds an empty tensor with the specified dtype and shape to the input
+ // vector.
+ Status AddDatasetInput(gtl::InlinedVector<TensorValue, 4>* inputs,
+ DataTypeVector input_types, DataType dtype,
+ const TensorShape& shape);
+
+ protected:
+ std::unique_ptr<Device> device_;
+ DeviceType device_type_;
+ Allocator* allocator_; // Owned by `AllocatorFactoryRegistry`.
+ std::vector<AllocatorAttributes> allocator_attrs_;
+ std::unique_ptr<ScopedStepContainer> step_container_;
+
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* flr_; // Owned by `pflr_`.
+ std::unique_ptr<FunctionHandleCache> function_handle_cache_;
+ std::function<void(std::function<void()>)> runner_;
+ std::unique_ptr<DeviceMgr> device_mgr_;
+ std::unique_ptr<FunctionLibraryDefinition> lib_def_;
+ std::unique_ptr<OpKernelContext::Params> params_;
+ std::unique_ptr<checkpoint::TensorSliceReaderCacheWrapper>
+ slice_reader_cache_;
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::vector<std::unique_ptr<Tensor>> tensors_; // Owns tensors.
+ mutex lock_for_refs_; // Used as the Mutex for inputs added as refs.
+};
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_TEST_BASE_H_
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index 4f7c8f1..9171b91 100644
--- a/tensorflow/core/kernels/data/experimental/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -233,6 +233,21 @@
)
tf_kernel_library(
+ name = "rebatch_dataset_op",
+ srcs = ["rebatch_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler/optimizers/data:rebatch",
+ "//tensorflow/core/kernels/data:graph_rewrite_dataset",
+ ],
+)
+
+tf_kernel_library(
name = "scan_dataset_op",
srcs = ["scan_dataset_op.cc"],
deps = [
@@ -391,6 +406,7 @@
":parse_example_dataset_op",
":prefetching_kernels",
":random_dataset_op",
+ ":rebatch_dataset_op",
":scan_dataset_op",
":set_stats_aggregator_dataset_op",
":sleep_dataset_op",
diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
new file mode 100644
index 0000000..a95773a
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kOptimizerName[] = "tf_data_rebatcher";
+
+class RebatchDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit RebatchDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 num_workers;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_workers", &num_workers));
+ OP_REQUIRES(ctx, num_workers > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+
+ Dataset* dataset =
+ new Dataset(ctx, input, num_workers, output_types_, output_shapes_);
+ Status s = dataset->Optimize(ctx);
+ if (s.ok()) {
+ *output = dataset;
+ } else {
+ dataset->Unref();
+ OP_REQUIRES_OK(ctx, s);
+ }
+ }
+
+ private:
+ class Dataset : public GraphRewriteDataset {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const int64 num_workers, const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : GraphRewriteDataset(ctx, input, output_types, output_shapes),
+ num_workers_(num_workers) {}
+
+ string DebugString() const override { return "RebatchDatasetOp::Dataset"; }
+
+ private:
+ RewriterConfig CreateGrapplerRewriteConfig() override {
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers(kOptimizerName);
+ rewriter_config.set_meta_optimizer_iterations(
+ RewriterConfig_NumIterationsType_ONE);
+ auto custom_optimizer = rewriter_config.add_custom_optimizers();
+ custom_optimizer->set_name(kOptimizerName);
+ AttrValue num_workers_attr;
+ num_workers_attr.set_i(num_workers_);
+ (*custom_optimizer->mutable_parameter_map())["num_workers"] =
+ num_workers_attr;
+ return rewriter_config;
+ }
+
+ const int64 num_workers_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalRebatchDataset").Device(DEVICE_CPU),
+ RebatchDatasetOp);
+
+} // anonymous namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
index bf96be4..be5fa4c 100644
--- a/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/experimental/stats_dataset_ops.cc
@@ -12,8 +12,6 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/example/example.pb.h"
-#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.cc b/tensorflow/core/kernels/data/graph_rewrite_dataset.cc
new file mode 100644
index 0000000..bc4bb46
--- /dev/null
+++ b/tensorflow/core/kernels/data/graph_rewrite_dataset.cc
@@ -0,0 +1,239 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
+
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace data {
+
+GraphRewriteDataset::~GraphRewriteDataset() {
+ input_->Unref();
+ if (optimized_input_) {
+ optimized_input_->Unref();
+ }
+}
+
+Status GraphRewriteDataset::Optimize(OpKernelContext* ctx) {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* input_node = nullptr;
+ SerializationContext::Params params;
+ std::vector<std::pair<string, Tensor>> input_list;
+ params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ params.input_list = &input_list;
+ params.optimization_only = true;
+ SerializationContext serialization_ctx(params);
+ TF_RETURN_IF_ERROR(
+ db.AddInputDataset(&serialization_ctx, input_, &input_node));
+ string output_node = input_node->name();
+
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ VLOG(3) << "Before optimization: " << graph_def.DebugString();
+
+ TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
+ VLOG(3) << "After optimization: " << graph_def.DebugString();
+
+ // Instantiate the optimized input pipeline by running the optimized graph
+ // using the optimized function library.
+ TF_RETURN_IF_ERROR(ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
+
+ // Create a FunctionHandleCache.
+ function_handle_cache_ = absl::make_unique<FunctionHandleCache>(lib_);
+
+ // Some functions may have been modified without having their names
+ // changed (for example, nested dataset graphs from FlatMap or
+ // Interleave). To avoid name conflicts, we remove these functions from
+ // flib_def_ before adding the optimized function library.
+ for (const FunctionDef& fd : graph_def.library().function()) {
+ if (flib_def_->Find(fd.signature().name()) != nullptr) {
+ TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name()));
+ }
+ }
+ TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
+
+ Graph graph(OpRegistry::Global());
+ TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
+ std::vector<Tensor> outputs;
+ GraphRunner graph_runner(ctx->function_library()->device());
+
+ TF_RETURN_IF_ERROR(
+ graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
+ TF_RETURN_IF_ERROR(
+ GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
+ optimized_input_->Ref();
+ return Status::OK();
+}
+
+Status GraphRewriteDataset::AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const {
+ // We only serialize the optimized dataset to avoid re-running
+ // optimizations when the input pipeline is restored from a checkpoint.
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
+ return Status::OK();
+}
+
+namespace {
+void AddFakeSinks(FunctionDef* function_def) {
+ int counter = 0;
+ for (const auto& output : function_def->signature().output_arg()) {
+ NodeDef* node = function_def->add_node_def();
+ tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
+ strings::StrCat("FakeSink", counter++), function_def, node);
+ node->set_op("Identity");
+ node->add_input(function_def->ret().at(output.name()));
+ (*node->mutable_attr())["T"].set_type(output.type());
+
+ (*function_def->mutable_ret())[output.name()] =
+ strings::StrCat(node->name(), ":output:0");
+ }
+}
+
+void RemoveFakeSinks(FunctionDef* function_def) {
+ // Map from identity node names to their input tensor strings
+ std::map<string, string> identity_map;
+ for (const auto& node : function_def->node_def()) {
+ if (node.op() == "Identity" && node.input_size() == 1) {
+ identity_map[node.name()] = node.input(0);
+ }
+ }
+ for (const auto& output_arg : function_def->signature().output_arg()) {
+ const string& tensor = function_def->ret().at(output_arg.name());
+ const string& output_node = tensor.substr(0, tensor.find(':'));
+ if (identity_map.find(output_node) != identity_map.end()) {
+ (*function_def->mutable_ret())[output_arg.name()] =
+ identity_map.at(output_node);
+ }
+ }
+}
+} // anonymous namespace
+
+Status GraphRewriteDataset::ApplyOptimizations(OpKernelContext* ctx,
+ GraphDef* graph_def,
+ string* output_node) {
+ // Add an identity node as the fetch node, otherwise we might get
+ // 'placeholder is both fed and fetched' errors in some cases when using
+ // input list with placeholder dataset nodes.
+ NodeDef* node = graph_def->mutable_node()->Add();
+ tensorflow::grappler::graph_utils::SetUniqueGraphNodeName("Sink", graph_def,
+ node);
+ node->set_op("Identity");
+ node->add_input(*output_node);
+ (*node->mutable_attr())["T"].set_type(DT_VARIANT);
+ *output_node = node->name();
+
+ // Add fake sink node to graph and functions to allow rewriting the actual
+ // sink nodes.
+ // TODO(b/118820916): When MetaOptimizer adds provisions for function
+ // retvals to be optimizable, we will no longer need this.
+ for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
+ AddFakeSinks(&function_def);
+ }
+
+ // Create metagraph.
+ MetaGraphDef meta_graph_def;
+ (*meta_graph_def.mutable_graph_def()) = *graph_def;
+
+ // Grappler determines fetch ops from collection 'train_op'.
+ CollectionDef collection_def;
+ auto node_list = collection_def.mutable_node_list();
+ node_list->add_value(*output_node);
+ (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
+
+ // Create Grappler item.
+ tensorflow::grappler::ItemConfig item_config;
+ item_config.apply_optimizations = true;
+ std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
+ tensorflow::grappler::GrapplerItemFromMetaGraphDef(
+ "graph", meta_graph_def, item_config);
+ std::unordered_map<string, tensorflow::DeviceProperties> device_map;
+ tensorflow::grappler::VirtualCluster cluster(device_map);
+
+ // Run data optimizer using grappler's meta optimizer.
+ tensorflow::ConfigProto config;
+ *config.mutable_graph_options()->mutable_rewrite_options() =
+ CreateGrapplerRewriteConfig();
+ TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
+ *grappler_item, config, ctx->device(), &cluster, graph_def));
+
+ // Remove fake sinks after optimizations are done.
+ // TODO(b/118820916): When MetaOptimizer adds provisions for function
+ // retvals to be optimizable, we will no longer need this.
+ for (auto& function_def : *graph_def->mutable_library()->mutable_function()) {
+ RemoveFakeSinks(&function_def);
+ }
+
+ return Status::OK();
+}
+
+class GraphRewriteDataset::Iterator
+ : public DatasetIterator<GraphRewriteDataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<GraphRewriteDataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext::Params params(ctx);
+ params.lib = dataset()->lib_;
+ params.function_handle_cache = dataset()->function_handle_cache_.get();
+ return dataset()->optimized_input_->MakeIterator(
+ IteratorContext(std::move(params)), prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ IteratorContext::Params params(ctx);
+ params.lib = dataset()->lib_;
+ params.function_handle_cache = dataset()->function_handle_cache_.get();
+ return input_impl_->GetNext(IteratorContext(std::move(params)), out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ std::shared_ptr<model::Node> CreateNode(
+ IteratorContext* ctx, model::Node::Args args) const override {
+ return model::MakeKnownRatioNode(std::move(args),
+ /*ratio=*/1);
+ }
+
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+};
+
+std::unique_ptr<IteratorBase> GraphRewriteDataset::MakeIteratorInternal(
+ const string& prefix) const {
+ // We do not add a token for this dataset to the prefix. The
+ // prefix is used to identify checkpoint elements and since this
+ // dataset is excluded from the checkpoint, adding a token
+ // here would result in invalid checkpoint identifiers.
+ return absl::make_unique<Iterator>(Iterator::Params{this, prefix});
+}
+
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/graph_rewrite_dataset.h b/tensorflow/core/kernels/data/graph_rewrite_dataset.h
new file mode 100644
index 0000000..dedbdce
--- /dev/null
+++ b/tensorflow/core/kernels/data/graph_rewrite_dataset.h
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
+
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function_handle_cache.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/grappler_item_builder.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+
+namespace tensorflow {
+namespace data {
+
+class GraphRewriteDataset : public DatasetBase {
+ public:
+ GraphRewriteDataset(OpKernelContext* ctx, const DatasetBase* input,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ optimized_input_(nullptr),
+ input_(input),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~GraphRewriteDataset() override;
+
+ // Runs Grappler to transform the input dataset into optimized_input_
+ // dataset.
+ Status Optimize(OpKernelContext* ctx);
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override;
+
+ const DataTypeVector& output_dtypes() const override { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ int64 Cardinality() const override { return input_->Cardinality(); }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override;
+
+ private:
+ class Iterator;
+
+ // Create a Grappler RewriteConfig proto that defines the list of
+ // optimizations to be run by the Grappler Meta Optimizer.
+ virtual RewriterConfig CreateGrapplerRewriteConfig() = 0;
+
+ Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
+ string* output_node);
+
+ DatasetBase* optimized_input_;
+ FunctionLibraryRuntime* lib_ = nullptr;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
+ std::unique_ptr<FunctionHandleCache> function_handle_cache_ = nullptr;
+ const DatasetBase* input_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_GRAPH_REWRITE_DATASET_H_
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index 808f834..0d2dfd9 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -967,78 +967,58 @@
}
}
-namespace {
+void IteratorGetNextAsOptionalOp::ComputeAsync(OpKernelContext* ctx,
+ DoneCallback done) {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule(std::bind(
+ [this, ctx, iterator](DoneCallback done) {
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
-class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
- public:
- explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
- : AsyncOpKernel(ctx),
- background_worker_(ctx->env(),
- "tf_data_iterator_get_next_as_optional") {
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
- }
+ Status s = iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
+ // NOTE(mrry): We must unref the iterator before calling `done()`, to
+ // avoid destruction races.
+ iterator->Unref();
- void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- IteratorResource* iterator;
- OP_REQUIRES_OK_ASYNC(
- ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
- // The call to `iterator->GetNext()` may block and depend on an
- // inter-op thread pool thread, so we issue the call from the
- // owned thread pool.
- background_worker_.Schedule(std::bind(
- [this, ctx, iterator](DoneCallback done) {
- std::vector<Tensor> components;
- bool end_of_sequence = false;
-
- Status s = iterator->GetNext(IteratorContext(ctx), &components,
- &end_of_sequence);
- // NOTE(mrry): We must unref the iterator before calling `done()`, to
- // avoid destruction races.
- iterator->Unref();
-
- if (!s.ok()) {
- ctx->SetStatus(s);
- } else if (end_of_sequence) {
- OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
- } else {
- for (int i = 0; i < components.size(); ++i) {
- OP_REQUIRES_ASYNC(
- ctx, components[i].dtype() == output_types_[i],
- errors::InvalidArgument(
- "The given optional does not match the expected type for "
- "component ",
- i, ". Expected: ", DataTypeString(output_types_[i]),
- ". Actual: ", DataTypeString(components[i].dtype()), "."),
- done);
- OP_REQUIRES_ASYNC(
- ctx,
- output_shapes_[i].IsCompatibleWith(components[i].shape()),
- errors::InvalidArgument(
- "The given optional does not match the expected shape "
- "for component ",
- i, ". Expected: ", output_shapes_[i].DebugString(),
- ". Actual: ", components[i].shape().DebugString(), "."),
- done);
- }
-
- OP_REQUIRES_OK_ASYNC(
- ctx,
- WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (end_of_sequence) {
+ OP_REQUIRES_OK_ASYNC(ctx, WriteOptionalNoneToOutput(ctx, 0), done);
+ } else {
+ for (int i = 0; i < components.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, components[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The given optional does not match the expected type for "
+ "component ",
+ i, ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(components[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx, output_shapes_[i].IsCompatibleWith(components[i].shape()),
+ errors::InvalidArgument(
+ "The given optional does not match the expected shape "
+ "for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", components[i].shape().DebugString(), "."),
done);
}
- done();
- },
- std::move(done)));
- }
- private:
- BackgroundWorker background_worker_;
- DataTypeVector output_types_;
- std::vector<PartialTensorShape> output_shapes_;
-};
-
-} // namespace
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ WriteOptionalWithValueToOutput(ctx, 0, std::move(components)),
+ done);
+ }
+ done();
+ },
+ std::move(done)));
+}
void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
const Tensor& resource_handle_t = ctx->input(0);
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index cd72269..7d769d3 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -19,6 +19,8 @@
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
@@ -115,6 +117,24 @@
BackgroundWorker background_worker_;
};
+class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
+ public:
+ explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(ctx->env(),
+ "tf_data_iterator_get_next_as_optional") {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override;
+
+ private:
+ BackgroundWorker background_worker_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
class IteratorGetNextSyncOp : public OpKernel {
public:
explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index 95f4c1c..e516d77 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -138,7 +138,6 @@
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
diff --git a/tensorflow/core/kernels/data/map_dataset_op_test.cc b/tensorflow/core/kernels/data/map_dataset_op_test.cc
new file mode 100644
index 0000000..f9c1cf4
--- /dev/null
+++ b/tensorflow/core/kernels/data/map_dataset_op_test.cc
@@ -0,0 +1,534 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_handle_cache.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/iterator_ops.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kNodeName[] = "map_dataset";
+constexpr char kOpName[] = "MapDataset";
+
+class MapDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+ // Creates a new MapDataset op kernel. The `input_dataset` parameter should be
+ // same with the node name of the input dataset for the method
+ // `CreateMapDatasetContext()`. `T` specifies the output dtype of MapDataset.
+ template <typename T>
+ Status CreateMapDatasetOpKernel(const string& input_dataset,
+ const string& func_name,
+ std::unique_ptr<OpKernel>* map_kernel) {
+ FunctionDefHelper::AttrValueWrapper func =
+ FunctionDefHelper::FunctionRef(func_name, {{"T", DT_INT64}});
+
+ map_node_def_ = test::function::NDef(
+ kNodeName, kOpName, {input_dataset},
+ {{"f", func},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{{}}},
+ {"output_types",
+ gtl::ArraySlice<DataType>{tensorflow::DataTypeToEnum<T>::value}},
+ {"use_inter_op_parallelism", true},
+ {"preserve_cardinality", false}});
+ TF_CHECK_OK(CreateOpKernel(map_node_def_, map_kernel));
+ return Status::OK();
+ }
+
+ // Creates a new MapDataset op kernel context.
+ Status CreateMapDatasetContext(
+ DatasetBase* const input_dataset, OpKernel* const map_kernel,
+ std::unique_ptr<OpKernelContext>* map_context) {
+ map_inputs_.clear();
+ // Save the input dataset into a variant tensor as the input of MapDataset.
+ Tensor dataset_tensor(DT_VARIANT, TensorShape({}));
+ TF_RETURN_IF_ERROR(
+ StoreDatasetInVariantTensor(input_dataset, &dataset_tensor));
+ Variant variant = dataset_tensor.scalar<Variant>()();
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<Variant>(
+ &map_inputs_, map_kernel->input_types(), TensorShape({}), {variant}));
+ input_dataset->Ref();
+ TF_RETURN_IF_ERROR(
+ CreateOpKernelContext(map_kernel, &map_inputs_, map_context));
+ TF_RETURN_IF_ERROR(CheckOpKernelInput(*map_kernel, map_inputs_));
+ return Status::OK();
+ }
+
+ private:
+ NodeDef map_node_def_;
+ gtl::InlinedVector<TensorValue, 4> map_inputs_;
+};
+
+struct GetNextTestParams {
+ explicit GetNextTestParams(int64 input_start, int64 input_end,
+ int64 input_step, string input_func_name,
+ std::vector<int64> input_expected_values,
+ std::vector<FunctionDef> input_func_lib)
+ : start(input_start),
+ end(input_end),
+ step(input_step),
+ func_name(std::move(input_func_name)),
+ expected_values(std::move(input_expected_values)),
+ func_lib(std::move(input_func_lib)) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+ string func_name;
+ std::vector<int64> expected_values;
+ std::vector<FunctionDef> func_lib;
+};
+
+struct DatasetGetNextTest : MapDatasetOpTest,
+ ::testing::WithParamInterface<GetNextTestParams> {};
+
+TEST_P(DatasetGetNextTest, GetNext) {
+ int thread_num = 2, cpu_num = 2;
+ GetNextTestParams test_params = GetParam();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
+ test_params.step, "range",
+ &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), test_params.func_name, &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(
+ map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
+ bool end_of_sequence = false;
+ std::vector<Tensor> out_tensors;
+ while (!end_of_sequence) {
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ }
+
+ EXPECT_EQ(out_tensors.size(), test_params.expected_values.size());
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ int64 actual_value = out_tensors[i].flat<int64>()(0);
+ int64 expect_value = test_params.expected_values[i];
+ EXPECT_EQ(actual_value, expect_value);
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(
+ MapDatasetOpTest, DatasetGetNextTest,
+ ::testing::Values(
+ GetNextTestParams(
+ 0, 10, 3, "XTimesTwo", std::vector<int64>{0, 6, 12, 18},
+ std::vector<FunctionDef>{test::function::XTimesTwo()}),
+ GetNextTestParams(0, 10, 3, "XAddX", std::vector<int64>{0, 6, 12, 18},
+ std::vector<FunctionDef>{test::function::XAddX()}),
+ GetNextTestParams(
+ 10, 0, -3, "XTimesFour", std::vector<int64>{40, 28, 16, 4},
+ std::vector<FunctionDef>{test::function::XTimesTwo(),
+ test::function::XTimesFour()})));
+
+TEST_F(MapDatasetOpTest, DatasetName) {
+ int thread_num = 2, cpu_num = 2;
+ int64 start = 0, end = 10, step = 1;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ EXPECT_EQ(map_dataset->name(), kOpName);
+}
+
+TEST_F(MapDatasetOpTest, DatasetOutputDtypes) {
+ int thread_num = 2, cpu_num = 2;
+ int64 start = 0, end = 10, step = 1;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ DataTypeVector expected_dtypes({DT_INT64});
+ EXPECT_EQ(map_dataset->output_dtypes(), expected_dtypes);
+}
+
+TEST_F(MapDatasetOpTest, DatasetOutputShapes) {
+ int thread_num = 2, cpu_num = 2;
+ int64 start = 0, end = 10, step = 1;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
+ EXPECT_EQ(map_dataset->output_shapes().size(), expected_shapes.size());
+ for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
+ EXPECT_TRUE(
+ map_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
+ }
+}
+
+struct CardinalityTestParams {
+ explicit CardinalityTestParams(int64 input_start, int64 input_end,
+ int64 input_step,
+ int input_expected_cardinality)
+ : start(input_start),
+ end(input_end),
+ step(input_step),
+ expected_cardinality(input_expected_cardinality) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+ int expected_cardinality;
+};
+
+struct DatasetCardinalityTest
+ : MapDatasetOpTest,
+ ::testing::WithParamInterface<CardinalityTestParams> {};
+
+TEST_P(DatasetCardinalityTest, Cardinality) {
+ int thread_num = 2, cpu_num = 2;
+ CardinalityTestParams test_params = GetParam();
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
+ test_params.step, "range",
+ &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ EXPECT_EQ(map_dataset->Cardinality(), test_params.expected_cardinality);
+}
+
+INSTANTIATE_TEST_CASE_P(MapDatasetOpTest, DatasetCardinalityTest,
+ ::testing::Values(CardinalityTestParams(0, 10, 1, 10),
+ CardinalityTestParams(0, 10, 3, 4),
+ CardinalityTestParams(10, 0, -3, 4)));
+
+TEST_F(MapDatasetOpTest, DatasetSave) {
+ int thread_num = 2, cpu_num = 2;
+ int64 start = 0, end = 10, step = 1;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<SerializationContext> serialization_context;
+ TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
+ VariantTensorData data;
+ VariantTensorDataWriter writer(&data);
+ TF_ASSERT_OK(map_dataset->Save(serialization_context.get(), &writer));
+ TF_ASSERT_OK(writer.Flush());
+}
+
+TEST_F(MapDatasetOpTest, IteratorOutputDtypes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(
+ map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
+ DataTypeVector expected_dtypes({DT_INT64});
+ EXPECT_EQ(iterator->output_dtypes(), expected_dtypes);
+}
+
+TEST_F(MapDatasetOpTest, IteratorOutputShapes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(
+ map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
+
+ std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
+ EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
+ for (int i = 0; i < map_dataset->output_shapes().size(); ++i) {
+ EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
+ }
+}
+
+TEST_F(MapDatasetOpTest, IteratorOutputPrefix) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+ FunctionDef func_def = test::function::XTimesTwo();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({func_def}, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateRangeDataset<int64>(start, end, step, "range", &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), func_def.signature().name(), &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(
+ map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
+
+ EXPECT_EQ(iterator->prefix(), "Iterator::Map");
+}
+
+struct RoundtripTestParams {
+ explicit RoundtripTestParams(int64 input_start, int64 input_end,
+ int64 input_step, int input_breakpoint,
+ int64 input_expected_value,
+ string input_func_name,
+ std::vector<FunctionDef> input_func_lib)
+ : start(input_start),
+ end(input_end),
+ step(input_step),
+ breakpoint(input_breakpoint),
+ expected_value(input_expected_value),
+ func_name(std::move(input_func_name)),
+ func_lib(std::move(input_func_lib)) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+ int breakpoint;
+ int64 expected_value;
+ string func_name;
+ std::vector<FunctionDef> func_lib;
+};
+
+struct IteratorRoundtripTest
+ : MapDatasetOpTest,
+ ::testing::WithParamInterface<RoundtripTestParams> {};
+
+TEST_P(IteratorRoundtripTest, Roundtrip) {
+ int thread_num = 2, cpu_num = 2;
+ RoundtripTestParams test_params = GetParam();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime(test_params.func_lib, cpu_num));
+
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(CreateRangeDataset<int64>(test_params.start, test_params.end,
+ test_params.step, "range",
+ &range_dataset));
+ core::ScopedUnref scored_unref_range_dataset(range_dataset);
+
+ std::unique_ptr<OpKernel> map_kernel;
+ TF_ASSERT_OK(CreateMapDatasetOpKernel<int64>(
+ range_dataset->name(), test_params.func_name, &map_kernel));
+ std::unique_ptr<OpKernelContext> map_context;
+ TF_ASSERT_OK(
+ CreateMapDatasetContext(range_dataset, map_kernel.get(), &map_context));
+ DatasetBase* map_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(map_kernel.get(), map_context.get(), &map_dataset));
+ core::ScopedUnref scored_unref_map_dataset(map_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(map_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(
+ map_dataset->MakeIterator(iterator_context.get(), "Iterator", &iterator));
+
+ std::vector<Tensor> out_tensors;
+ bool end_of_sequence = false;
+ for (int i = 0; i < test_params.breakpoint; i++) {
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ }
+
+ std::unique_ptr<SerializationContext> serialization_context;
+ TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
+ VariantTensorData data;
+ VariantTensorDataWriter writer(&data);
+ TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
+ TF_ASSERT_OK(writer.Flush());
+ VariantTensorDataReader reader(&data);
+ TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ EXPECT_EQ(out_tensors.back().flat<int64>()(0), test_params.expected_value);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ MapDatasetOpTest, IteratorRoundtripTest,
+ ::testing::Values(RoundtripTestParams(0, 10, 2, 0, 0, "XTimesTwo",
+ std::vector<FunctionDef>{
+ test::function::XTimesTwo()}),
+ RoundtripTestParams(0, 10, 2, 4, 16, "XAddX",
+ std::vector<FunctionDef>{
+ test::function::XAddX()}),
+ RoundtripTestParams(0, 10, 2, 6, 32, "XTimesFour",
+ std::vector<FunctionDef>{
+ test::function::XTimesTwo(),
+ test::function::XTimesFour()})));
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 6047dc5..17094e3 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -14,26 +14,11 @@
==============================================================================*/
#include <map>
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/graph_runner.h"
-#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/framework/device_base.h"
-#include "tensorflow/core/framework/function_handle_cache.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
-#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/grappler_item_builder.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
-#include "tensorflow/core/lib/core/refcount.h"
+#include "tensorflow/core/kernels/data/graph_rewrite_dataset.h"
#include "tensorflow/core/lib/random/random.h"
-#include "tensorflow/core/protobuf/meta_graph.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
@@ -71,235 +56,20 @@
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphRewriteDataset {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : DatasetBase(DatasetContext(ctx)),
- optimized_input_(nullptr),
- input_(input),
- optimizations_(optimizations),
- output_types_(output_types),
- output_shapes_(output_shapes) {
- input_->Ref();
- }
-
- ~Dataset() override {
- input_->Unref();
- if (optimized_input_) {
- optimized_input_->Unref();
- }
- }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- // We do not add a token for the optimization dataset to the prefix. The
- // prefix is used to identify checkpoint elements and since the
- // optimization dataset is excluded from the checkpoint, adding a token
- // here would result in invalid checkpoint identifiers.
- return absl::make_unique<Iterator>(Iterator::Params{this, prefix});
- }
-
- Status Optimize(OpKernelContext* ctx) {
- GraphDefBuilder b;
- DatasetGraphDefBuilder db(&b);
- Node* input_node = nullptr;
- SerializationContext::Params params;
- std::vector<std::pair<string, Tensor>> input_list;
- params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
- params.input_list = &input_list;
- params.optimization_only = true;
- SerializationContext serialization_ctx(params);
- TF_RETURN_IF_ERROR(
- db.AddInputDataset(&serialization_ctx, input_, &input_node));
- string output_node = input_node->name();
-
- GraphDef graph_def;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- VLOG(3) << "Before optimization: " << graph_def.DebugString();
-
- TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
- VLOG(3) << "After optimization: " << graph_def.DebugString();
-
- // Instantiate the optimized input pipeline by running the optimized graph
- // using the optimized function library.
- TF_RETURN_IF_ERROR(
- ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
-
- // Create a FunctionHandleCache.
- function_handle_cache_ = absl::make_unique<FunctionHandleCache>(lib_);
-
- // Some functions may have been modified without having their names
- // changed (for example, nested dataset graphs from FlatMap or
- // Interleave). To avoid name conflicts, we remove these functions from
- // flib_def_ before adding the optimized function library.
- for (const FunctionDef& fd : graph_def.library().function()) {
- if (flib_def_->Find(fd.signature().name()) != nullptr) {
- TF_RETURN_IF_ERROR(flib_def_->RemoveFunction(fd.signature().name()));
- }
- }
- TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
-
- Graph graph(OpRegistry::Global());
- TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
- std::vector<Tensor> outputs;
- GraphRunner graph_runner(ctx->function_library()->device());
-
- TF_RETURN_IF_ERROR(
- graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
- TF_RETURN_IF_ERROR(
- GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
- optimized_input_->Ref();
- return Status::OK();
- }
-
- const DataTypeVector& output_dtypes() const override {
- return output_types_;
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return output_shapes_;
- }
+ : GraphRewriteDataset(ctx, input, output_types, output_shapes),
+ optimizations_(optimizations) {}
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
- int64 Cardinality() const override { return input_->Cardinality(); }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- // We only serialize the optimized dataset to avoid re-running
- // optimizations when the input pipeline is restored from a checkpoint.
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, optimized_input_, output));
- return Status::OK();
- }
-
private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- IteratorContext::Params params(ctx);
- params.lib = dataset()->lib_;
- params.function_handle_cache = dataset()->function_handle_cache_.get();
- return dataset()->optimized_input_->MakeIterator(
- IteratorContext(std::move(params)), prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- IteratorContext::Params params(ctx);
- params.lib = dataset()->lib_;
- params.function_handle_cache = dataset()->function_handle_cache_.get();
- return input_impl_->GetNext(IteratorContext(std::move(params)),
- out_tensors, end_of_sequence);
- }
-
- protected:
- std::shared_ptr<model::Node> CreateNode(
- IteratorContext* ctx, model::Node::Args args) const override {
- return model::MakeKnownRatioNode(std::move(args),
- /*ratio=*/1);
- }
-
- Status SaveInternal(IteratorStateWriter* writer) override {
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- std::unique_ptr<IteratorBase> input_impl_;
- };
-
- void AddFakeSinks(FunctionDef* function_def) {
- int counter = 0;
- for (const auto& output : function_def->signature().output_arg()) {
- NodeDef* node = function_def->add_node_def();
- tensorflow::grappler::function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("FakeSink", counter++), function_def, node);
- node->set_op("Identity");
- node->add_input(function_def->ret().at(output.name()));
- (*node->mutable_attr())["T"].set_type(output.type());
-
- (*function_def->mutable_ret())[output.name()] =
- strings::StrCat(node->name(), ":output:0");
- }
- }
-
- void RemoveFakeSinks(FunctionDef* function_def) {
- // Map from identity node names to their input tensor strings
- std::map<string, string> identity_map;
- for (const auto& node : function_def->node_def()) {
- if (node.op() == "Identity" && node.input_size() == 1) {
- identity_map[node.name()] = node.input(0);
- }
- }
- for (const auto& output_arg : function_def->signature().output_arg()) {
- const string& tensor = function_def->ret().at(output_arg.name());
- const string& output_node = tensor.substr(0, tensor.find(':'));
- if (identity_map.find(output_node) != identity_map.end()) {
- (*function_def->mutable_ret())[output_arg.name()] =
- identity_map.at(output_node);
- }
- }
- }
-
- Status ApplyOptimizations(OpKernelContext* ctx, GraphDef* graph_def,
- string* output_node) {
- // Add an identity node as the fetch node, otherwise we might get
- // 'placeholder is both fed and fetched' errors in some cases when using
- // input list with placeholder dataset nodes.
- NodeDef* node = graph_def->mutable_node()->Add();
- tensorflow::grappler::graph_utils::SetUniqueGraphNodeName(
- "Sink", graph_def, node);
- node->set_op("Identity");
- node->add_input(*output_node);
- (*node->mutable_attr())["T"].set_type(DT_VARIANT);
- *output_node = node->name();
-
- // Add fake sink node to graph and functions to allow rewriting the actual
- // sink nodes.
- // TODO(b/118820916): When MetaOptimizer adds provisions for function
- // retvals to be optimizable, we will no longer need this.
- for (auto& function_def :
- *graph_def->mutable_library()->mutable_function()) {
- AddFakeSinks(&function_def);
- }
-
- // Create metagraph.
- MetaGraphDef meta_graph_def;
- (*meta_graph_def.mutable_graph_def()) = *graph_def;
-
- // Grappler determines fetch ops from collection 'train_op'.
- CollectionDef collection_def;
- auto node_list = collection_def.mutable_node_list();
- node_list->add_value(*output_node);
- (*meta_graph_def.mutable_collection_def())["train_op"] = collection_def;
-
- // Create Grappler item.
- tensorflow::grappler::ItemConfig item_config;
- item_config.apply_optimizations = true;
- std::unique_ptr<tensorflow::grappler::GrapplerItem> grappler_item =
- tensorflow::grappler::GrapplerItemFromMetaGraphDef(
- "graph", meta_graph_def, item_config);
- std::unordered_map<string, tensorflow::DeviceProperties> device_map;
- tensorflow::grappler::VirtualCluster cluster(device_map);
-
- // Run data optimizer using grappler's meta optimizer.
- tensorflow::ConfigProto config;
- RewriterConfig& rewriter_config =
- *config.mutable_graph_options()->mutable_rewrite_options();
+ RewriterConfig CreateGrapplerRewriteConfig() override {
+ RewriterConfig rewriter_config;
rewriter_config.add_optimizers(kOptimizerName);
rewriter_config.set_meta_optimizer_iterations(
RewriterConfig_NumIterationsType_ONE);
@@ -311,30 +81,10 @@
for (const auto& opt : optimizations_) {
custom_optimizations_list->add_s(opt);
}
-
- TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
- *grappler_item, config, ctx->device(), &cluster, graph_def));
-
- // Remove fake sinks after optimizations are done.
- // TODO(b/118820916): When MetaOptimizer adds provisions for function
- // retvals to be optimizable, we will no longer need this.
- for (auto& function_def :
- *graph_def->mutable_library()->mutable_function()) {
- RemoveFakeSinks(&function_def);
- }
-
- return Status::OK();
+ return rewriter_config;
}
- DatasetBase* optimized_input_;
- FunctionLibraryRuntime* lib_ = nullptr;
- std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
- std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
- std::unique_ptr<FunctionHandleCache> function_handle_cache_ = nullptr;
- const DatasetBase* input_;
const std::vector<string> optimizations_;
- const DataTypeVector output_types_;
- const std::vector<PartialTensorShape> output_shapes_;
};
const int graph_def_version_;
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index aa14d27..87390ad 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -64,7 +64,7 @@
const std::vector<PartialTensorShape>& output_shapes() const override {
static std::vector<PartialTensorShape>* shapes =
- new std::vector<PartialTensorShape>({{}});
+ new std::vector<PartialTensorShape>({PartialTensorShape({})});
return *shapes;
}
diff --git a/tensorflow/core/kernels/data/range_dataset_op_test.cc b/tensorflow/core/kernels/data/range_dataset_op_test.cc
new file mode 100644
index 0000000..0bbc09a
--- /dev/null
+++ b/tensorflow/core/kernels/data/range_dataset_op_test.cc
@@ -0,0 +1,421 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/variant.h"
+#include "tensorflow/core/framework/variant_tensor_data.h"
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/data/iterator_ops.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+constexpr char kOpName[] = "RangeDataset";
+
+class RangeDatasetOpTest : public DatasetOpsTestBase {
+ protected:
+ // Creates a new RangeDataset op kernel context.
+ Status CreateRangeDatasetContext(
+ int64 start, int64 end, int64 step, OpKernel* const range_kernel,
+ std::unique_ptr<OpKernelContext>* range_context) {
+ inputs_.clear();
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &inputs_, range_kernel->input_types(), TensorShape({}), {start}));
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &inputs_, range_kernel->input_types(), TensorShape({}), {end}));
+ TF_RETURN_IF_ERROR(AddDatasetInputFromArray<int64>(
+ &inputs_, range_kernel->input_types(), TensorShape({}), {step}));
+
+ TF_RETURN_IF_ERROR(
+ CreateOpKernelContext(range_kernel, &inputs_, range_context));
+ TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, inputs_));
+ return Status::OK();
+ }
+
+ private:
+ gtl::InlinedVector<TensorValue, 4> inputs_;
+};
+
+struct GetNextTestParams {
+ explicit GetNextTestParams(int64 input_start, int64 input_end,
+ int64 input_step)
+ : start(input_start), end(input_end), step(input_step) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+};
+
+struct DatasetGetNextTest : RangeDatasetOpTest,
+ ::testing::WithParamInterface<GetNextTestParams> {};
+
+TEST_P(DatasetGetNextTest, GetNext) {
+ int thread_num = 2, cpu_num = 2;
+ GetNextTestParams params = GetParam();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
+ range_kernel.get(), &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
+ &iterator));
+
+ bool end_of_sequence = false;
+ std::vector<Tensor> out_tensors;
+ while (!end_of_sequence) {
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ }
+ std::vector<int> expected_values;
+ for (int i = params.start; (params.end - i) * params.step > 0;
+ i = i + params.step) {
+ expected_values.reserve(1);
+ expected_values.emplace_back(i);
+ }
+ EXPECT_EQ(out_tensors.size(), expected_values.size());
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ int64 actual_value = out_tensors[i].flat<int64>()(0);
+ int64 expect_value = expected_values[i];
+ EXPECT_EQ(actual_value, expect_value);
+ }
+}
+
+INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetGetNextTest,
+ ::testing::Values(GetNextTestParams(0, 10, 1),
+ GetNextTestParams(0, 10, 3),
+ GetNextTestParams(10, 0, -1),
+ GetNextTestParams(10, 0, -3)));
+
+TEST_F(RangeDatasetOpTest, DatasetName) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ EXPECT_EQ(range_dataset->name(), kOpName);
+}
+
+TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ DataTypeVector expected_dtypes({DT_INT64});
+ EXPECT_EQ(range_dataset->output_dtypes(), expected_dtypes);
+}
+
+TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
+ EXPECT_EQ(range_dataset->output_shapes().size(), expected_shapes.size());
+ for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
+ EXPECT_TRUE(
+ range_dataset->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
+ }
+}
+
+struct CardinalityTestParams {
+ explicit CardinalityTestParams(int64 input_start, int64 input_end,
+ int64 input_step,
+ int input_expected_cardinality)
+ : start(input_start),
+ end(input_end),
+ step(input_step),
+ expected_cardinality(input_expected_cardinality) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+ int expected_cardinality;
+};
+
+struct DatasetCardinalityTest
+ : RangeDatasetOpTest,
+ ::testing::WithParamInterface<CardinalityTestParams> {};
+
+TEST_P(DatasetCardinalityTest, Cardinality) {
+ int thread_num = 2, cpu_num = 2;
+ CardinalityTestParams params = GetParam();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
+ range_kernel.get(), &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ EXPECT_EQ(range_dataset->Cardinality(), params.expected_cardinality);
+}
+
+INSTANTIATE_TEST_CASE_P(RangeDatasetOpTest, DatasetCardinalityTest,
+ ::testing::Values(CardinalityTestParams(0, 10, 1, 10),
+ CardinalityTestParams(0, 10, 3, 4),
+ CardinalityTestParams(10, 0, -3, 4)));
+
+TEST_F(RangeDatasetOpTest, DatasetSave) {
+ int64 thread_num = 2, cpu_num = 2;
+ int start = 0, end = 10, step = 1;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<SerializationContext> serialization_context;
+ TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
+
+ VariantTensorData data;
+ VariantTensorDataWriter writer(&data);
+ TF_ASSERT_OK(range_dataset->Save(serialization_context.get(), &writer));
+ TF_ASSERT_OK(writer.Flush());
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
+ &iterator));
+
+ DataTypeVector expected_dtypes({DT_INT64});
+ EXPECT_EQ(iterator->output_dtypes(), expected_dtypes);
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
+ &iterator));
+
+ std::vector<PartialTensorShape> expected_shapes({PartialTensorShape({})});
+ EXPECT_EQ(iterator->output_shapes().size(), expected_shapes.size());
+ for (int i = 0; i < range_dataset->output_shapes().size(); ++i) {
+ EXPECT_TRUE(iterator->output_shapes()[i].IsIdenticalTo(expected_shapes[i]));
+ }
+}
+
+TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
+ int64 start = 0, end = 10, step = 1;
+ int thread_num = 2, cpu_num = 2;
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(start, end, step, range_kernel.get(),
+ &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
+ &iterator));
+
+ EXPECT_EQ(iterator->prefix(), "Iterator::Range");
+}
+
+struct RoundtripTestParams {
+ explicit RoundtripTestParams(int64 input_start, int64 input_end,
+ int64 input_step, int input_breakpoint)
+ : start(input_start),
+ end(input_end),
+ step(input_step),
+ breakpoint(input_breakpoint) {}
+
+ int64 start;
+ int64 end;
+ int64 step;
+ int breakpoint;
+};
+
+struct IteratorRoundtripTest
+ : RangeDatasetOpTest,
+ ::testing::WithParamInterface<RoundtripTestParams> {};
+
+TEST_P(IteratorRoundtripTest, Roundtrip) {
+ int thread_num = 2, cpu_num = 2;
+ RoundtripTestParams params = GetParam();
+
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ std::unique_ptr<OpKernel> range_kernel;
+ TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>("range", &range_kernel));
+ std::unique_ptr<OpKernelContext> range_context;
+ TF_ASSERT_OK(CreateRangeDatasetContext(params.start, params.end, params.step,
+ range_kernel.get(), &range_context));
+ DatasetBase* range_dataset;
+ TF_ASSERT_OK(
+ CreateDataset(range_kernel.get(), range_context.get(), &range_dataset));
+ core::ScopedUnref scored_unref(range_dataset);
+
+ std::unique_ptr<IteratorContext> iterator_context;
+ TF_ASSERT_OK(CreateIteratorContext(range_context.get(), &iterator_context));
+ std::unique_ptr<IteratorBase> iterator;
+ TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(), "Iterator",
+ &iterator));
+
+ std::vector<Tensor> out_tensors;
+ bool end_of_sequence = false;
+ int64 cur_val = params.start - params.step;
+ for (int i = 0; i < params.breakpoint; i++) {
+ if (!end_of_sequence) {
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ cur_val = ((params.end - cur_val - params.step) * params.step > 0)
+ ? cur_val + params.step
+ : cur_val;
+ }
+ }
+
+ std::unique_ptr<SerializationContext> serialization_context;
+ TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
+ VariantTensorData data;
+ VariantTensorDataWriter writer(&data);
+ TF_ASSERT_OK(iterator->Save(serialization_context.get(), &writer));
+ TF_ASSERT_OK(writer.Flush());
+ VariantTensorDataReader reader(&data);
+ TF_ASSERT_OK(iterator->Restore(iterator_context.get(), &reader));
+ TF_EXPECT_OK(iterator->GetNext(iterator_context.get(), &out_tensors,
+ &end_of_sequence));
+ int64 expect_next = ((params.end - cur_val - params.step) * params.step > 0)
+ ? cur_val + params.step
+ : cur_val;
+ EXPECT_EQ(out_tensors.back().flat<int64>()(0), expect_next);
+}
+
+INSTANTIATE_TEST_CASE_P(
+ RangeDatasetOpTest, IteratorRoundtripTest,
+ ::testing::Values(
+ RoundtripTestParams(0, 10, 2, 0), // unused_iterator
+ RoundtripTestParams(0, 10, 2, 4), // fully_used_iterator_increase
+ RoundtripTestParams(10, 0, -2, 4), // fully_used_iterator_decrease
+ RoundtripTestParams(0, 10, 2, 6))); // exhausted_iterator
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shard_dataset_op.cc b/tensorflow/core/kernels/data/shard_dataset_op.cc
new file mode 100644
index 0000000..9bb6491
--- /dev/null
+++ b/tensorflow/core/kernels/data/shard_dataset_op.cc
@@ -0,0 +1,195 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/util/batch_util.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class ShardDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ShardDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 index = 0;
+ int64 num_shards = 0;
+
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "num_shards", &num_shards));
+ OP_REQUIRES(
+ ctx, num_shards > 0,
+ errors::InvalidArgument("Number of shards must be greater than zero "
+ "(currently num_shards = ",
+ num_shards, ")."));
+
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "index", &index));
+ OP_REQUIRES(
+ ctx, index >= 0 && index < num_shards,
+ errors::InvalidArgument("Index must be between 0 and ", num_shards - 1,
+ " (currently index = ", index, ")."));
+
+ *output = new Dataset(ctx, num_shards, index, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, int64 num_shards, int64 index,
+ const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)),
+ num_shards_(num_shards),
+ index_(index),
+ input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return absl::make_unique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Shard")});
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("ShardDatasetOp(", num_shards_, ", ", index_,
+ ")::Dataset");
+ }
+
+ int64 Cardinality() const override {
+ int64 n = input_->Cardinality();
+ if (n == kInfiniteCardinality || n == kUnknownCardinality) {
+ return n;
+ }
+ return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* num_shards = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(num_shards_, &num_shards));
+ Node* index = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(index_, &index));
+ TF_RETURN_IF_ERROR(
+ b->AddDataset(this, {input_graph_node, num_shards, index}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params), next_index_(0) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ std::vector<Tensor> result;
+ do {
+ result.clear();
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, &result, end_of_sequence));
+ if (*end_of_sequence) {
+ input_impl_.reset();
+ return Status::OK();
+ }
+ } while ((next_index_++ % dataset()->num_shards_) != dataset()->index_);
+
+ *out_tensors = std::move(result);
+ return Status::OK();
+ }
+
+ protected:
+ std::shared_ptr<model::Node> CreateNode(
+ IteratorContext* ctx, model::Node::Args args) const override {
+ return model::MakeKnownRatioNode(std::move(args),
+ dataset()->num_shards_);
+ }
+
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (!input_impl_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ } else {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("next_index"), next_index_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("next_index"), &next_index_));
+ } else {
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 next_index_ GUARDED_BY(mu_);
+ };
+
+ const int64 num_shards_;
+ const int64 index_;
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ShardDataset").Device(DEVICE_CPU),
+ ShardDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/decode_proto_op.cc b/tensorflow/core/kernels/decode_proto_op.cc
index b54e1ea..06dc766 100644
--- a/tensorflow/core/kernels/decode_proto_op.cc
+++ b/tensorflow/core/kernels/decode_proto_op.cc
@@ -31,6 +31,7 @@
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -625,8 +626,37 @@
// Gather the field descriptors and check that requested output types match.
int field_index = 0;
std::vector<const FieldDescriptor*> field_descs;
+ std::vector<const FieldDescriptor*> exts;
+ absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
+ std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
for (const string& name : field_names) {
auto fd = message_desc->FindFieldByName(name);
+ if (fd == nullptr) {
+ // If field can't be found in original message, try to find a matching
+ // extension (by its full_name). First check a hashmap for a matching
+ // extension, and if not found, then iterate through available
+ // extensions to find a match (updating the hashmap while iterating.)
+ auto lookup_result = ext_name_to_field.find(name);
+ if (lookup_result != ext_name_to_field.end()) {
+ fd = lookup_result->second;
+ } else {
+ if (ext_it == exts.begin()) {
+ desc_pool->FindAllExtensions(message_desc, &exts);
+ ext_it = exts.begin();
+ }
+ while (ext_it != exts.end()) {
+ auto ext_name = (*ext_it)->full_name();
+ auto ext_field = *ext_it;
+ ++ext_it;
+
+ ext_name_to_field.insert({ext_name, ext_field});
+ if (ext_name == name) {
+ fd = ext_field;
+ break;
+ }
+ }
+ }
+ }
OP_REQUIRES(context, fd != nullptr,
errors::InvalidArgument("Unknown field: ", name,
" in message type ", message_type));
diff --git a/tensorflow/core/kernels/functional_ops.cc b/tensorflow/core/kernels/functional_ops.cc
index 5ecb203..a3e7b53 100644
--- a/tensorflow/core/kernels/functional_ops.cc
+++ b/tensorflow/core/kernels/functional_ops.cc
@@ -120,6 +120,7 @@
opts->stats_collector = ctx->stats_collector();
}
opts->runner = ctx->runner();
+ opts->step_container = ctx->step_container();
}
class IfOp : public AsyncOpKernel {
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 1e449dd..1286381 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -148,6 +148,7 @@
if (t.shape() == TensorShape({})) {
if ((t.dtype() == DT_INT32 && t.scalar<int32>()() == -1) ||
(t.dtype() == DT_INT64 && t.scalar<int64>()() == -1)) {
+ *out = PartialTensorShape();
return Status::OK();
}
return errors::InvalidArgument(
@@ -636,6 +637,10 @@
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
TensorListConcat<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListConcat<CPUDevice, T>) \
REGISTER_KERNEL_BUILDER(Name("TensorListGetItem") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 5259389..652ca2b 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -64,6 +64,13 @@
.Device(DEVICE_GPU) \
.HostMemory("lengths"), \
TensorListConcat<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListConcatV2") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("leading_dims") \
+ .HostMemory("element_shape") \
+ .HostMemory("lengths"), \
+ TensorListConcat<GPUDevice, T>) \
REGISTER_KERNEL_BUILDER(Name("TensorListPushBackBatch") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU), \
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 7b3ff07..c25e9ce 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -209,11 +209,29 @@
OP_REQUIRES_OK(
c, GetElementShapeFromInput(c, *l, 2, &partial_element_shape));
TensorShape element_shape;
+ // If l->element_shape and the element_shape input are both not fully
+ // defined, try to infer the shape from other list elements. This requires
+ // that all initialized list elements have the same shape.
+ // NOTE(srbs): This might be a performance bottleneck since we are
+ // iterating over the entire list here. This is necessary for feature
+ // parity with TensorArray.read. TensorArray has a mode in which all
+ // elements are required to be of the same shape, TensorList does not.
+ // In that mode TensorArray sets the array's element_shape on the first
+ // write call. We could do something similar here if needed.
+ if (!partial_element_shape.IsFullyDefined()) {
+ for (const Tensor& t : l->tensors) {
+ if (t.dtype() != DT_INVALID) {
+ PartialTensorShape tmp = partial_element_shape;
+ OP_REQUIRES_OK(c, tmp.MergeWith(t.shape(), &partial_element_shape));
+ }
+ }
+ }
OP_REQUIRES(
c, partial_element_shape.AsTensorShape(&element_shape),
errors::InvalidArgument("Trying to read an uninitialized tensor but ",
- "element_shape is not fully defined.",
- partial_element_shape.DebugString()));
+ "element_shape is not fully defined: ",
+ partial_element_shape.DebugString(),
+ " and no list element is set."));
Tensor* result;
AllocatorAttributes attr;
if (element_dtype_ == DT_VARIANT) {
@@ -327,60 +345,83 @@
errors::InvalidArgument(
"Invalid data types; op elements ", DataTypeString(element_dtype_),
" but list elements ", DataTypeString(tensor_list->element_dtype)));
- // If the TensorList is empty, its element_shape must be fully defined
- // except for the first dimension.
- if (!element_shape_except_first_dim_.IsFullyDefined()) {
- if (!tensor_list->element_shape.unknown_rank()) {
- OP_REQUIRES(c, tensor_list->element_shape.dims() >= 1,
- errors::InvalidArgument(
- "Concat requires elements to be at least vectors, ",
- "found scalars instead."));
- PartialTensorShape shape_except_first_dim(
- gtl::ArraySlice<int64>(tensor_list->element_shape.dim_sizes())
- .subspan(1));
- PartialTensorShape tmp = element_shape_except_first_dim_;
- OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
- &element_shape_except_first_dim_));
- }
+ // The leading dimension of all list elements if they are all the same.
+ // This is used as the leading dim of uninitialized tensors in the list
+ // if leading_dims is not provided.
+ int64 first_dim = -1;
+ if (c->num_inputs() > 1) {
+ // TensorListConcatV2
+ PartialTensorShape element_shape;
+ OP_REQUIRES_OK(
+ c, GetElementShapeFromInput(c, *tensor_list, 1, &element_shape));
+ OP_REQUIRES(c, element_shape.unknown_rank() || element_shape.dims() >= 1,
+ errors::InvalidArgument(
+ "Concat requires elements to be at least vectors, ",
+ "found scalars instead."));
+ // Split `element_shape` into `first_dim` and
+ // `element_shape_except_first_dim_`.
+ first_dim = element_shape.dim_size(0);
+ element_shape_except_first_dim_ = element_shape;
+ element_shape_except_first_dim_.RemoveDim(0);
}
+ // If the TensorList is empty, element_shape_except_first_dim_ must be fully
+ // defined.
OP_REQUIRES(c,
!tensor_list->tensors.empty() ||
element_shape_except_first_dim_.IsFullyDefined(),
errors::InvalidArgument(
"All except the first dimension must be fully defined ",
"when concating an empty tensor list. element_shape: ",
- tensor_list->element_shape.DebugString()));
- // 1. Compute the shape of the output tensor.
- // If `element_shape_except_first_dim_` is fully-defined we just prepend the
- // leading dim to it. Otherwise we use the shape of the first element tensor
- // and check to make sure shapes of all tensors are compatible.
- TensorShape output_shape;
- if (!element_shape_except_first_dim_.AsTensorShape(&output_shape)) {
- const Tensor& element_tensor = tensor_list->tensors[0];
- OP_REQUIRES(
- c, TensorShapeUtils::IsVectorOrHigher(element_tensor.shape()),
- errors::InvalidArgument("Concat saw a scalar shape at index ", 0,
- " but requires at least vectors."));
- output_shape =
- TensorShape(gtl::ArraySlice<int64>(element_tensor.shape().dim_sizes())
- .subspan(1));
- for (int i = 1; i < tensor_list->tensors.size(); ++i) {
- const Tensor& element_tensor = tensor_list->tensors[i];
- OP_REQUIRES(
- c, TensorShapeUtils::IsVectorOrHigher(element_tensor.shape()),
- errors::InvalidArgument("Concat saw a scalar shape at index ", i,
- " but requires at least vectors."));
- TensorShape actual_shape(
- gtl::ArraySlice<int64>(element_tensor.shape().dim_sizes())
- .subspan(1));
- OP_REQUIRES(c, actual_shape.dim_sizes() == output_shape.dim_sizes(),
- errors::InvalidArgument(
- "Tried to concat tensors with unequal shapes: ",
- output_shape.DebugString(), " vs ",
- actual_shape.DebugString()));
+ element_shape_except_first_dim_.DebugString()));
+ // 1. Check that `element_shape_except_first_dim_` input tensor is
+ // compatible with the shapes of element tensors.
+ // 2. Check that the elements have the same shape except the first dim.
+ // 3. If `first_dim` is known, check that it is compatible with the leading
+ // dims of all elements.
+ // 4. If `first_dim` is unknown (-1), check whether all initialized
+ // elements have the same leading dim and if so set `first_dim` to that
+ // value.
+ if (!tensor_list->element_shape.IsFullyDefined()) {
+ bool check_dim = (first_dim == -1);
+ int64 inferred_first_dim = first_dim;
+ for (int i = 0; i < tensor_list->tensors.size(); ++i) {
+ const Tensor& t = tensor_list->tensors[i];
+ if (t.dtype() != DT_INVALID) {
+ PartialTensorShape tmp = element_shape_except_first_dim_;
+ OP_REQUIRES(
+ c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument("Concat saw a scalar shape at index ", i,
+ " but requires at least vectors."));
+ TensorShape shape_except_first_dim = TensorShape(
+ gtl::ArraySlice<int64>(t.shape().dim_sizes()).subspan(1));
+ OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
+ &element_shape_except_first_dim_));
+ OP_REQUIRES(c, first_dim == -1 || first_dim == t.shape().dim_size(0),
+ errors::InvalidArgument(
+ "First entry of element_shape input does not match ",
+ "the first dim of list element at index: ", i,
+ " Expected: ", first_dim,
+ " Actual: ", t.shape().dim_size(0)));
+ if (check_dim) {
+ if (inferred_first_dim == -1) {
+ inferred_first_dim = t.shape().dim_size(0);
+ } else if (inferred_first_dim != t.shape().dim_size(0)) {
+ inferred_first_dim = -1;
+ check_dim = false;
+ }
+ }
+ }
}
+ first_dim = inferred_first_dim;
}
- // 2. Build the lengths_tensor and leading dim of the output tensor by
+ TensorShape output_shape;
+ OP_REQUIRES(
+ c, element_shape_except_first_dim_.AsTensorShape(&output_shape),
+ errors::InvalidArgument(
+ "Trying to concat list with only uninitialized tensors ",
+ "but element_shape_except_first_dim_ is not fully defined: ",
+ element_shape_except_first_dim_.DebugString()));
+ // Build the lengths_tensor and leading dim of the output tensor by
// iterating over all element tensors.
Tensor* lengths_tensor = nullptr;
OP_REQUIRES_OK(
@@ -391,13 +432,36 @@
auto lengths_tensor_vec = lengths_tensor->vec<int64>();
int64 leading_dim = 0;
for (size_t i = 0; i < tensor_list->tensors.size(); i++) {
- int64 dim = tensor_list->tensors[i].shape().dim_size(0);
+ int64 dim;
+ if (tensor_list->tensors[i].dtype() != DT_INVALID) {
+ dim = tensor_list->tensors[i].shape().dim_size(0);
+ } else {
+ // If leading_dims is not provided or does not contain an entry for
+ // index i use the inferred `first_dim` if set.
+ if ((c->num_inputs() <= 2 || i >= c->input(2).NumElements()) &&
+ first_dim != -1) {
+ dim = first_dim;
+ } else {
+ OP_REQUIRES(c, c->num_inputs() > 2,
+ errors::InvalidArgument(
+ "Concating lists with uninitialized tensors is not ",
+ "supported in this version of TensorListConcat. ",
+ "Consider updating your GraphDef to run the newer ",
+ "version."));
+ OP_REQUIRES(c, i < c->input(2).NumElements(),
+ errors::InvalidArgument(
+ "List contains uninitialized tensor at index ", i,
+ " but leading_dims has only ",
+ c->input(2).NumElements(), " elements."));
+ dim = c->input(2).vec<int64>()(i);
+ }
+ }
leading_dim += dim;
lengths_tensor_vec(i) = dim;
}
output_shape.InsertDim(0, leading_dim);
Tensor* output;
- // 3. Allocate the output tensor and fill it up with the concated element
+ // Allocate the output tensor and fill it up with the concated element
// tensors.
OP_REQUIRES_OK(c, c->allocate_output(0, output_shape, &output));
if (output->NumElements() == 0) {
@@ -406,9 +470,31 @@
ConstMatrixVector inputs_flat;
inputs_flat.reserve(tensor_list->tensors.size());
- for (const auto& element_tensor : tensor_list->tensors) {
- inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
- element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
+ // Store the zeros tensors in a vector to prevent them from being GC'ed till
+ // concat is complete.
+ std::vector<Tensor> zeros_vec;
+ for (int i = 0; i < tensor_list->tensors.size(); i++) {
+ const Tensor& element_tensor = tensor_list->tensors[i];
+ if (element_tensor.dtype() != DT_INVALID) {
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ element_tensor.shaped<T, 2>({1, element_tensor.NumElements()})));
+ } else {
+ AllocatorAttributes attr;
+ if (element_dtype_ == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TensorShape element_shape = output_shape;
+ element_shape.set_dim(0, lengths_tensor_vec(i));
+ zeros_vec.emplace_back();
+ Tensor& zeros = zeros_vec.back();
+ OP_REQUIRES_OK(
+ c, c->allocate_temp(element_dtype_, element_shape, &zeros, attr));
+ functor::SetZeroFunctor<Device, T>()(c->eigen_device<Device>(),
+ zeros.flat<T>());
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ const_cast<const Tensor&>(zeros).shaped<T, 2>(
+ {1, zeros.NumElements()})));
+ }
}
auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
@@ -522,7 +608,7 @@
errors::InvalidArgument(
"Invalid data types; op elements ", DataTypeString(element_dtype_),
" but list elements ", DataTypeString(tensor_list->element_dtype)));
- Tensor indices = c->input(1);
+ const Tensor& indices = c->input(1);
PartialTensorShape partial_element_shape;
OP_REQUIRES_OK(c, GetElementShapeFromInput(c, *tensor_list, 2,
&partial_element_shape));
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 2599340..e611ae2 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -13,7 +13,10 @@
limitations under the License.
==============================================================================*/
+#include "tensorflow/core/kernels/logging_ops.h"
+
#include <iostream>
+
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -48,6 +51,22 @@
} // namespace
+namespace logging {
+
+typedef std::vector<void (*)(const char*)> Listeners;
+
+Listeners* GetListeners() {
+ static Listeners* listeners = new Listeners;
+ return listeners;
+}
+
+bool RegisterListener(void (*listener)(const char*)) {
+ GetListeners()->push_back(listener);
+ return true;
+}
+
+} // end namespace logging
+
class AssertOp : public OpKernel {
public:
explicit AssertOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
@@ -157,7 +176,12 @@
OP_REQUIRES_OK(ctx, AppendStringToFile(file_path_, msg, ctx->env()));
return;
}
- if (output_stream_ == "stdout") {
+ auto listeners = logging::GetListeners();
+ if (!listeners->empty()) {
+ for (auto& listener : *listeners) {
+ listener(msg.c_str());
+ }
+ } else if (output_stream_ == "stdout") {
std::cout << msg << std::endl;
} else if (output_stream_ == "stderr") {
std::cerr << msg << std::endl;
diff --git a/tensorflow/core/kernels/logging_ops.h b/tensorflow/core/kernels/logging_ops.h
new file mode 100644
index 0000000..92a8d63
--- /dev/null
+++ b/tensorflow/core/kernels/logging_ops.h
@@ -0,0 +1,33 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+
+namespace tensorflow {
+
+namespace logging {
+
+// Register a listener method to call on any printed messages.
+// Returns true if it is successfully registered.
+bool RegisterListener(void (*listener)(const char*));
+
+} // namespace logging
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_LOGGING_OPS_H_
diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc
index 3a5c874..b95bbca 100644
--- a/tensorflow/core/kernels/mkl_concat_op.cc
+++ b/tensorflow/core/kernels/mkl_concat_op.cc
@@ -25,12 +25,14 @@
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/kernels/concat_lib_cpu.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/mkl_util.h"
using mkldnn::concat;
using mkldnn::stream;
-#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
@@ -226,8 +228,50 @@
// format and avoid calling eigen version.
if (!are_all_tf_inputs && !are_all_mkl_inputs) invoke_eigen = true;
+ OpInputList input_mins, input_maxes;
+ if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+ // MKL-DNN concat does not support input tensors that have different
+ // ranges. Check if the ranges of the all input tensors are the same.
+ // If not, forward it to Eigen implementation.
+
+ OP_REQUIRES_OK(context, context->input_list("input_mins", &input_mins));
+ OP_REQUIRES(context, (input_mins.size() == N),
+ errors::InvalidArgument(
+ "QuantizedConcatOp : Expected mins input list length ",
+ input_mins.size(), " to equal values length ", N));
+
+ OP_REQUIRES_OK(context,
+ context->input_list("input_maxes", &input_maxes));
+ OP_REQUIRES(context, (input_maxes.size() == N),
+ errors::InvalidArgument(
+ "QuantizedConcatOp : Expected maxes input list length ",
+ input_maxes.size(), " to equal values length ", N));
+ float input_min = input_mins[0].flat<float>()(0);
+ float input_max = input_maxes[0].flat<float>()(0);
+ const float eps = 1.0e-6;
+ for (int i = 1; i < N; ++i) {
+ float min = input_mins[i].flat<float>()(0);
+ float max = input_maxes[i].flat<float>()(0);
+
+ if (fabs(input_min - min) > eps || fabs(input_max - max) > eps) {
+ invoke_eigen = true;
+ break;
+ }
+ }
+ }
+
// Call Eigen library
if (invoke_eigen) {
+ // MKL-DNN quantized concat does not support input tensors with
+ // different ranges.
+ // TODO (mabuzain): Add quantized version of CallEigen() to support
+ // this case.
+ OP_REQUIRES(
+ context,
+ (!std::is_same<T, qint8>::value && !std::is_same<T, quint8>::value),
+ errors::Unimplemented("MKL DNN quantized concat does not "
+ "support input tensors that have "
+ "different ranges"));
CallEigenVersion(context, input_tensors, mkl_input_shapes);
return;
}
@@ -374,6 +418,23 @@
std::vector<primitive> net;
net.push_back(concat_op);
stream(stream::kind::eager).submit(net).wait();
+
+ // For quantized concat, min and max outputs are also computed.
+ if (std::is_same<T, qint8>::value || std::is_same<T, quint8>::value) {
+ Tensor* output_min = nullptr;
+ Tensor* output_max = nullptr;
+ MklDnnShape output_min_mkl_shape, output_max_mkl_shape;
+ output_min_mkl_shape.SetMklTensor(false);
+ output_max_mkl_shape.SetMklTensor(false);
+ AllocateOutputSetMklShape(context, 1, &output_min, {},
+ output_min_mkl_shape);
+ AllocateOutputSetMklShape(context, 2, &output_max, {},
+ output_max_mkl_shape);
+ // All input tensors should have the same range, just use the
+ // first one
+ output_min->flat<float>()(0) = input_mins[0].flat<float>()(0);
+ output_max->flat<float>()(0) = input_maxes[0].flat<float>()(0);
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -490,6 +551,20 @@
TF_CALL_float(REGISTER_MKL_CPU);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<quint8>("T")
+ .HostMemory("axis")
+ .Label(mkl_op_registry::kMklQuantizedOpLabel),
+ MklConcatOp<CPUDevice, quint8, NAME_IS_AXIS>)
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<qint8>("T")
+ .HostMemory("axis")
+ .Label(mkl_op_registry::kMklQuantizedOpLabel),
+ MklConcatOp<CPUDevice, qint8, NAME_IS_AXIS>)
+
#undef REGISTER_CONCAT_MKL
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl_fused_ops_test.cc
index 756ee90..288515d 100644
--- a/tensorflow/core/kernels/mkl_fused_ops_test.cc
+++ b/tensorflow/core/kernels/mkl_fused_ops_test.cc
@@ -434,7 +434,7 @@
// Compare outputs to expected results
const Tensor& output = *GetOutput(0);
const Tensor& output_layout = *GetOutput(2);
- ConvMklToTF<T> conv_comp;
+ CommonTestUtilities<T> conv_comp;
conv_comp.ConvertAndCompare(dtype, output, output_layout, expected);
// TODO(bhavanis): For now, we rely on internal performance tests to
@@ -446,7 +446,7 @@
// Compare output to expected results
const Tensor& output_new = *GetOutput(0);
const Tensor& output_layout_new = *GetOutput(2);
- ConvMklToTF<T> conv_comp_new;
+ CommonTestUtilities<T> conv_comp_new;
conv_comp_new.ConvertAndCompare(dtype, output_new, output_layout_new,
expected);
}
diff --git a/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
new file mode 100644
index 0000000..fc68480
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_quantized_concat_op_test.cc
@@ -0,0 +1,234 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
+
+#define EIGEN_USE_THREADS
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/kernels/quantization_utils.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+using test::graph::Constant;
+
+static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
+static const TensorShape dummy_shape({8});
+
+// Helper class for converting MKL tensors to TF tensors and comparing to
+// expected values
+
+class ConvMklToTF : public OpsTestBase {
+ public:
+ template <typename T>
+ void ConvertMKL2TF(DataType dtype, const Tensor& first, const Tensor& second,
+ Tensor& output) {
+ // Create an MKL to TF conversion node and execute it
+ TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
+ .Input(FakeInput(dtype)) // Input
+ .Input(FakeInput(DT_UINT8)) // MKL second tensor
+ .Attr("T", dtype)
+ .Attr("_kernel", "MklOp")
+ .Finalize(node_def()));
+ TF_EXPECT_OK(InitOp());
+ AddInputFromArray<T>(first.shape(), first.flat<T>());
+ AddInputFromArray<uint8>(second.shape(), second.flat<uint8>());
+ TF_ASSERT_OK(RunOpKernel());
+
+ output = *GetOutput(0);
+ }
+ void TestBody(){};
+};
+
+class QuantizedConcatTest : public OpsTestBase {
+ protected:
+ QuantizedConcatTest() {}
+
+ void TestSmall8Bit(float first_min, float first_max, float second_min,
+ float second_max);
+ void TestSecondDim8Bit(float first_min, float first_max, float second_min,
+ float second_max);
+};
+
+TEST_F(QuantizedConcatTest, Small8BitSameRange) {
+ // Range for both is the same, so impl can use memcpy.
+ TestSmall8Bit(0.0f, 255.0f, 0.0f, 255.0f);
+}
+
+void QuantizedConcatTest::TestSmall8Bit(float first_min, float first_max,
+ float second_min, float second_max) {
+ TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
+ .Input(FakeInput(2, DT_QUINT8))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(2, DT_FLOAT))
+ .Input(FakeInput(2, DT_FLOAT))
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Attr("N", 2)
+ .Attr("T", DataTypeToEnum<quint8>::v())
+ .Attr("Tidx", DT_INT32)
+ .Attr("_kernel", "QuantizedMklOp")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ const int first_batch = 2;
+ const int first_height = 2;
+ const int first_width = 3;
+ const int first_depth = 1;
+ Tensor first_float(DT_FLOAT,
+ {first_batch, first_height, first_width, first_depth});
+ test::FillValues<float>(&first_float,
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ Tensor first_quantized =
+ FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
+
+ const int second_batch = 2;
+ const int second_height = 2;
+ const int second_width = 3;
+ const int second_depth = 1;
+ Tensor second_float(
+ DT_FLOAT, {second_batch, second_height, second_width, second_depth});
+ test::FillValues<float>(&second_float,
+ {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+ Tensor second_quantized =
+ FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
+
+ const int expected_batch = first_batch + second_batch;
+ Tensor expected_float(
+ DT_FLOAT, {expected_batch, first_height, first_width, first_depth});
+ test::FillValues<float>(&expected_float,
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+
+ AddInputFromArray<quint8>(first_quantized.shape(),
+ first_quantized.flat<quint8>());
+ AddInputFromArray<quint8>(second_quantized.shape(),
+ second_quantized.flat<quint8>());
+ AddInputFromArray<int32>(TensorShape({}), {0});
+ AddInputFromArray<float>(TensorShape({}), {first_min});
+ AddInputFromArray<float>(TensorShape({}), {second_min});
+ AddInputFromArray<float>(TensorShape({}), {first_max});
+ AddInputFromArray<float>(TensorShape({}), {second_max});
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ TF_ASSERT_OK(RunOpKernel());
+ const Tensor& output_quantized = *GetOutput(0);
+ const float output_min = GetOutput(1)->flat<float>()(0);
+ const float output_max = GetOutput(2)->flat<float>()(0);
+ Tensor output_float =
+ QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
+ test::ExpectTensorNear<float>(expected_float, output_float, 0.2);
+}
+
+TEST_F(QuantizedConcatTest, SecondDim8BitSameRange) {
+ TestSecondDim8Bit(-10.0f, 150.0f, -10.0f, 150.0f);
+}
+
+void QuantizedConcatTest::TestSecondDim8Bit(float first_min, float first_max,
+ float second_min,
+ float second_max) {
+ TF_ASSERT_OK(NodeDefBuilder("quantized_concat_op", "_MklQuantizedConcatV2")
+ .Input(FakeInput(2, DT_QUINT8))
+ .Input(FakeInput(DT_INT32))
+ .Input(FakeInput(2, DT_FLOAT))
+ .Input(FakeInput(2, DT_FLOAT))
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Input(FakeInput(2, DT_UINT8)) // MKL second tensor
+ .Attr("N", 2)
+ .Attr("T", DataTypeToEnum<quint8>::v())
+ .Attr("Tidx", DT_INT32)
+ .Attr("_kernel", "QuantizedMklOp")
+ .Finalize(node_def()));
+ TF_ASSERT_OK(InitOp());
+ const int first_batch = 2;
+ const int first_height = 2;
+ const int first_width = 3;
+ const int first_depth = 1;
+ Tensor first_float(DT_FLOAT,
+ {first_batch, first_height, first_width, first_depth});
+ test::FillValues<float>(&first_float,
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ Tensor first_quantized =
+ FloatTensorToQuantized<quint8>(first_float, first_min, first_max);
+
+ const int second_batch = 2;
+ const int second_height = 2;
+ const int second_width = 3;
+ const int second_depth = 1;
+
+ Tensor second_float(
+ DT_FLOAT, {second_batch, second_height, second_width, second_depth});
+ test::FillValues<float>(&second_float,
+ {13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24});
+ Tensor second_quantized =
+ FloatTensorToQuantized<quint8>(second_float, second_min, second_max);
+
+ const int expected_height = first_height + second_height;
+ Tensor expected_float(
+ DT_FLOAT, {first_batch, expected_height, first_width, first_depth});
+ test::FillValues<float>(&expected_float,
+ {1, 2, 3, 4, 5, 6, 13, 14, 15, 16, 17, 18,
+ 7, 8, 9, 10, 11, 12, 19, 20, 21, 22, 23, 24});
+
+ AddInputFromArray<quint8>(first_quantized.shape(),
+ first_quantized.flat<quint8>());
+ AddInputFromArray<quint8>(second_quantized.shape(),
+ second_quantized.flat<quint8>());
+ AddInputFromArray<int32>(TensorShape({}), {1});
+ AddInputFromArray<float>(TensorShape({}), {first_min});
+ AddInputFromArray<float>(TensorShape({}), {second_min});
+ AddInputFromArray<float>(TensorShape({}), {first_max});
+ AddInputFromArray<float>(TensorShape({}), {second_max});
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+ TF_ASSERT_OK(RunOpKernel());
+ const Tensor& output_quantized = *GetOutput(0);
+ const float output_min = GetOutput(1)->flat<float>()(0);
+ const float output_max = GetOutput(2)->flat<float>()(0);
+ Tensor output_float =
+ QuantizedTensorToFloat<quint8>(output_quantized, output_min, output_max);
+ // Using the same error tolerance as in Eigen QuantizedConcat test
+ test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
+}
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/kernels/mutex_ops.cc b/tensorflow/core/kernels/mutex_ops.cc
index 1603a2a..b06845f 100644
--- a/tensorflow/core/kernels/mutex_ops.cc
+++ b/tensorflow/core/kernels/mutex_ops.cc
@@ -242,10 +242,24 @@
REGISTER_KERNEL_BUILDER(Name("MutexLock").Device(DEVICE_CPU), MutexLockOp);
-REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("MutexLock")
+ .Device(DEVICE_GPU)
+ .HostMemory("mutex_lock")
+ .HostMemory("mutex"),
+ MutexLockOp);
+
+REGISTER_KERNEL_BUILDER(
+ Name("MutexV2").Device(DEVICE_CPU).HostMemory("resource"),
+ ResourceHandleOp<Mutex>);
+
+REGISTER_KERNEL_BUILDER(Name("MutexV2").Device(DEVICE_GPU),
ResourceHandleOp<Mutex>);
REGISTER_KERNEL_BUILDER(Name("ConsumeMutexLock").Device(DEVICE_CPU),
ConsumeMutexLockOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ConsumeMutexLock").Device(DEVICE_GPU).HostMemory("mutex_lock"),
+ ConsumeMutexLockOp);
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/pack_op.cc b/tensorflow/core/kernels/pack_op.cc
index 5645275..18ed1ea 100644
--- a/tensorflow/core/kernels/pack_op.cc
+++ b/tensorflow/core/kernels/pack_op.cc
@@ -158,7 +158,8 @@
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_bfloat16(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
-REGISTER_GPU(bool);
+TF_CALL_int16(REGISTER_GPU);
+TF_CALL_bool(REGISTER_GPU);
#undef REGISTER_GPU
// A special GPU kernel for int32.
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 5d26265..d8c38fe 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -157,25 +157,18 @@
Status Instantiate(FunctionLibraryRuntime* lib, OpKernelContext* ctx,
std::vector<Tensor>* inputs,
FunctionLibraryRuntime::Handle* handle) {
- // We are going to execute the graph via function library runtime, and
- // because function execution semantics is slightly different from the
- // regular tensorlow graph, we need to make sure that Grappler respects it
- // when doing it's optimization passes (e.g. do not prune stateful and
- // dataset ops).
grappler::GrapplerItem::OptimizationOptions optimization_options;
- optimization_options.is_function_instantiation = true;
- // Keras graphs expected to be executed with regular graph execution
- // semantics (it's allowed to prune stateful and dataset ops).
- if (absl::StrContains(func_.name(), "keras_graph")) {
- optimization_options.is_function_instantiation = false;
- }
+ // Tensorflow 2.0 in eager mode with automatic control dependencies will
+ // prune all nodes that are not in the transitive fanin of the fetch nodes.
+ // However because the function will be executed via FunctionLibraryRuntime,
+ // and current function implementation does not prune stateful and dataset
+ // ops, we rely on Grappler to do the correct graph pruning.
+ optimization_options.allow_pruning_stateful_and_dataset_ops = true;
- // Wrapped function expects execution semantics to be the same as
- // `session.run`, so we should prune unreachable stateful and dataset ops.
- if (absl::StrContains(func_.name(), "wrapped_function")) {
- optimization_options.is_function_instantiation = false;
- }
+ // All the nested function calls will be executed and optimized via
+ // PartitionedCallOp, there is no need to optimize functions now.
+ optimization_options.optimize_function_library = false;
FunctionLibraryRuntime::InstantiateOptions opts;
opts.target = lib->device()->name();
@@ -183,7 +176,7 @@
opts.optimize_graph_fn = std::bind(
grappler::OptimizeGraph, std::placeholders::_1, std::placeholders::_2,
std::placeholders::_3, std::placeholders::_4, config_proto_,
- optimization_options, std::placeholders::_5);
+ func_.name(), optimization_options, std::placeholders::_5);
opts.graph_collector = ctx->graph_collector();
opts.executor_type = executor_type_;
diff --git a/tensorflow/core/kernels/quantized_concat_op.cc b/tensorflow/core/kernels/quantized_concat_op.cc
index b03ac8e..ff4e7be 100644
--- a/tensorflow/core/kernels/quantized_concat_op.cc
+++ b/tensorflow/core/kernels/quantized_concat_op.cc
@@ -246,4 +246,16 @@
#undef REGISTER_QUANTIZED_CONCAT
+#ifdef INTEL_MKL
+#define REGISTER_QUANTIZED_CONCATV2(type) \
+ REGISTER_KERNEL_BUILDER(Name("QuantizedConcatV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("axis"), \
+ QuantizedConcatOp<type>)
+
+REGISTER_QUANTIZED_CONCATV2(quint8);
+REGISTER_QUANTIZED_CONCATV2(qint32);
+#endif
+
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 2bb2c0d..cbc754a 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -310,7 +310,10 @@
void Examples::RandomShuffle() {
std::iota(sampled_index_.begin(), sampled_index_.end(), 0);
- std::random_shuffle(sampled_index_.begin(), sampled_index_.end());
+
+ std::random_device rd;
+ std::mt19937 rng(rd());
+ std::shuffle(sampled_index_.begin(), sampled_index_.end(), rng);
}
// TODO(sibyl-Aix6ihai): Refactor/shorten this function.
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.cc b/tensorflow/core/lib/bfloat16/bfloat16.cc
index a591717..e6e24bc 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.cc
+++ b/tensorflow/core/lib/bfloat16/bfloat16.cc
@@ -19,6 +19,9 @@
namespace tensorflow {
+const uint16_t bfloat16::NAN_VALUE;
+const uint16_t bfloat16::ZERO_VALUE;
+
B16_DEVICE_FUNC bfloat16::operator Eigen::half() const {
return static_cast<Eigen::half>(float(*this));
}
diff --git a/tensorflow/core/lib/gif/gif_io.cc b/tensorflow/core/lib/gif/gif_io.cc
index ce842e9..dc54069 100644
--- a/tensorflow/core/lib/gif/gif_io.cc
+++ b/tensorflow/core/lib/gif/gif_io.cc
@@ -140,6 +140,10 @@
ColorMapObject* color_map = this_image->ImageDesc.ColorMap
? this_image->ImageDesc.ColorMap
: gif_file->SColorMap;
+ if (color_map == nullptr) {
+ *error_string = strings::StrCat("missing color map for frame ", k);
+ return nullptr;
+ }
for (int i = imgTop; i < imgBottom; ++i) {
uint8* p_dst = this_dst + i * width * channel;
diff --git a/tensorflow/core/lib/strings/proto_serialization.cc b/tensorflow/core/lib/strings/proto_serialization.cc
index 5c1fbda..a6c321c 100644
--- a/tensorflow/core/lib/strings/proto_serialization.cc
+++ b/tensorflow/core/lib/strings/proto_serialization.cc
@@ -14,20 +14,61 @@
==============================================================================*/
#include "tensorflow/core/lib/strings/proto_serialization.h"
+#include <cstring>
+#include "absl/memory/memory.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
string* result) {
- DCHECK_LE(msg.ByteSizeLong(), static_cast<size_t>(INT_MAX));
- const int size = static_cast<int>(msg.ByteSizeLong());
+ const size_t size = msg.ByteSizeLong();
+ DCHECK_LE(size, static_cast<size_t>(INT_MAX));
*result = string(size, '\0');
- protobuf::io::ArrayOutputStream array_stream(&(*result)[0], size);
+ return SerializeToBufferDeterministic(msg, const_cast<char*>(result->data()),
+ result->size());
+}
+
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+ char* buffer, size_t size) {
+ DCHECK(msg.ByteSizeLong() == size && size <= static_cast<size_t>(INT_MAX));
+ protobuf::io::ArrayOutputStream array_stream(buffer, size);
protobuf::io::CodedOutputStream output_stream(&array_stream);
output_stream.SetSerializationDeterministic(true);
msg.SerializeWithCachedSizes(&output_stream);
return !output_stream.HadError() && size == output_stream.ByteCount();
}
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+ const protobuf::MessageLite& y) {
+ const size_t size = x.ByteSizeLong();
+ if (size != y.ByteSizeLong()) return false;
+ if (size == 0) return true;
+ auto x_serialized = absl::make_unique<char[]>(size);
+ bool success_x = SerializeToBufferDeterministic(x, x_serialized.get(), size);
+ DCHECK(success_x);
+ auto y_serialized = absl::make_unique<char[]>(size);
+ bool success_y = SerializeToBufferDeterministic(y, y_serialized.get(), size);
+ DCHECK(success_y);
+ return memcmp(x_serialized.get(), y_serialized.get(), size) == 0;
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+ uint64 seed) {
+ const size_t size = proto.ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ SerializeToBufferDeterministic(proto, serialized.get(), size);
+ return Hash64(serialized.get(), size, seed);
+}
+
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto) {
+ const size_t size = proto.ByteSizeLong();
+ auto serialized = absl::make_unique<char[]>(size);
+ SerializeToBufferDeterministic(proto, serialized.get(), size);
+ return Hash64(serialized.get(), size);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/lib/strings/proto_serialization.h b/tensorflow/core/lib/strings/proto_serialization.h
index 6664928..763bd68 100644
--- a/tensorflow/core/lib/strings/proto_serialization.h
+++ b/tensorflow/core/lib/strings/proto_serialization.h
@@ -28,6 +28,21 @@
bool SerializeToStringDeterministic(const protobuf::MessageLite& msg,
string* result);
+// As above, but takes a pre-allocated buffer wrapped by result.
+// PRECONDITION: size == msg.ByteSizeLong() && size <= INT_MAX.
+bool SerializeToBufferDeterministic(const protobuf::MessageLite& msg,
+ char* buffer, size_t size);
+
+// Returns true if serializing x and y using
+// SerializeToBufferDeterministic() yields identical strings.
+bool AreSerializedProtosEqual(const protobuf::MessageLite& x,
+ const protobuf::MessageLite& y);
+
+// Computes Hash64 of the output of SerializeToBufferDeterministic().
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto);
+uint64 DeterministicProtoHash64(const protobuf::MessageLite& proto,
+ uint64 seed);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_LIB_STRINGS_PROTO_SERIALIZATION_H_
diff --git a/tensorflow/core/lib/strings/proto_serialization_test.cc b/tensorflow/core/lib/strings/proto_serialization_test.cc
new file mode 100644
index 0000000..cbde20e
--- /dev/null
+++ b/tensorflow/core/lib/strings/proto_serialization_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/lib/strings/proto_serialization.h"
+
+#include <string>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+static void BM_ProtoSerializationToString(int iters, int num_nodes) {
+ testing::StopTiming();
+ GraphDef graph_def;
+ for (int i = 0; i < num_nodes; ++i) {
+ NodeDef* node = graph_def.add_node();
+ node->set_name(strings::StrCat("node", i));
+ node->set_op(strings::StrCat("op", i % 10));
+ }
+ testing::StartTiming();
+ for (int i = 0; i < iters; ++i) {
+ string serialized;
+ testing::DoNotOptimize(
+ SerializeToStringDeterministic(graph_def, &serialized));
+ }
+ testing::StopTiming();
+}
+BENCHMARK(BM_ProtoSerializationToString)->Range(1, 10000);
+
+static void BM_ProtoSerializationToBuffer(int iters, int num_nodes) {
+ testing::StopTiming();
+ GraphDef graph_def;
+ for (int i = 0; i < num_nodes; ++i) {
+ NodeDef* node = graph_def.add_node();
+ node->set_name(strings::StrCat("node", i));
+ node->set_op(strings::StrCat("op", i % 10));
+ }
+ testing::StartTiming();
+ const size_t size = graph_def.ByteSizeLong();
+ for (int i = 0; i < iters; ++i) {
+ auto buf = absl::make_unique<char[]>(size);
+ testing::DoNotOptimize(
+ SerializeToBufferDeterministic(graph_def, buf.get(), size));
+ }
+ testing::StopTiming();
+}
+
+BENCHMARK(BM_ProtoSerializationToBuffer)->Range(1, 10000);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index f678957..7e5b448 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -23562,6 +23562,33 @@
is_stateful: true
}
op {
+ name: "ExperimentalRebatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_workers"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalScanDataset"
input_arg {
name: "input_dataset"
@@ -64526,6 +64553,37 @@
}
}
op {
+ name: "ShardDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_shards"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ShardedFilename"
input_arg {
name: "basename"
@@ -79119,6 +79177,43 @@
}
}
op {
+ name: "TensorListConcatV2"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ input_arg {
+ name: "leading_dims"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ output_arg {
+ name: "lengths"
+ type: DT_INT64
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListElementShape"
input_arg {
name: "input_handle"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1c11716..872a6da 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -27,7 +27,7 @@
// to a stateful "iterator" by passing the "dataset" to the
// "MakeIterator" op.
//
-// TODO(b/65524810): DT_VARIANT tensors that represent "dataset" objects are
+// TODO(b/123753214): DT_VARIANT tensors that represent "dataset" objects are
// not presently serializable. To avoid issues with constant folding, ensure
// that any "source dataset" ops (i.e. ops that output a dataset and do not
// take one as input) are marked "stateful".
@@ -37,7 +37,7 @@
.Output("handle: variant")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that
// `components` have shapes
@@ -49,7 +49,7 @@
.Output("handle: variant")
.Attr("Toutput_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape); // TODO(mrry): Validate that the
// dim-0 slices of `components`
@@ -62,7 +62,7 @@
.Input("dense_shape: int64")
.Output("handle: variant")
.Attr("Tvalues: type")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
@@ -79,7 +79,7 @@
.Attr("Tfinalize_func_args: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
@@ -275,6 +275,22 @@
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("ShardDataset")
+ .Input("input_dataset: variant")
+ .Input("num_shards: int64")
+ .Input("index: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // num_shards should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ // index should be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
// TODO(mrry): Validate that `padded_shapes` are all vectors, the lengths of
// `output_types` and `output_shapes` are `N` the `output_shapes` are (as far as
// possible to tell statically) compatible with `padded_shapes`, and that
@@ -322,7 +338,7 @@
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -388,7 +404,7 @@
.Input("compression_type: string")
.Input("buffer_size: int64")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -408,7 +424,7 @@
.Input("footer_bytes: int64")
.Input("buffer_size: int64")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -431,7 +447,7 @@
.Input("buffer_size: int64")
.Input("compression_type: string")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -451,7 +467,7 @@
.Input("compression_type: string")
.Input("buffer_size: int64")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index 316e405..95230af 100644
--- a/tensorflow/core/ops/experimental_dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -42,7 +42,7 @@
.Output("handle: variant")
.Attr("output_types: list({float,double,int32,int64,string}) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -190,6 +190,14 @@
return shape_inference::ScalarShape(c);
});
+REGISTER_OP("ExperimentalRebatchDataset")
+ .Input("input_dataset: variant")
+ .Input("num_workers: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("ExperimentalMapDataset")
.Input("input_dataset: variant")
.Input("other_arguments: Targuments")
@@ -205,7 +213,7 @@
REGISTER_OP("ExperimentalMatchingFilesDataset")
.Input("patterns: string")
.Output("handle: variant")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -259,7 +267,7 @@
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -330,7 +338,7 @@
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
@@ -459,7 +467,7 @@
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ .SetIsStateful() // TODO(b/123753214): Source dataset ops must be marked
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index fdaa5a2..f8fdb10 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -209,6 +209,42 @@
return Status::OK();
});
+Status TensorListConcatShapeInference(
+ shape_inference::InferenceContext* c,
+ shape_inference::ShapeHandle element_shape) {
+ DataType element_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data != nullptr && handle_data->size() > 1) {
+ return errors::InvalidArgument(
+ "Trying to read from list with wrong variant data.");
+ }
+ if (handle_data != nullptr && handle_data->size() == 1) {
+ const shape_inference::ShapeAndType& list_shape_type = (*handle_data)[0];
+ if (list_shape_type.dtype != element_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read from list with wrong element dtype. List has "
+ "type ",
+ DataTypeString(list_shape_type.dtype), " but expected type ",
+ DataTypeString(element_dtype));
+ }
+ shape_inference::ShapeHandle merged;
+ TF_RETURN_IF_ERROR(c->Merge(element_shape, list_shape_type.shape, &merged));
+ element_shape = merged;
+ }
+ if (c->RankKnown(element_shape)) {
+ shape_inference::ShapeHandle result;
+ TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &result));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->MakeShape({c->UnknownDim()}), result, &result));
+ c->set_output(0, result);
+ } else {
+ c->set_output(0, c->UnknownShape());
+ }
+ c->set_output(1, c->MakeShape({c->UnknownDim()}));
+ return Status::OK();
+}
+
REGISTER_OP("TensorListConcat")
.Input("input_handle: variant")
.Output("tensor: element_dtype")
@@ -216,45 +252,27 @@
.Attr("element_dtype: type")
.Attr("element_shape: shape = { unknown_rank: true }")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- DataType element_dtype;
- TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
PartialTensorShape raw_element_shape;
TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &raw_element_shape));
shape_inference::ShapeHandle element_shape;
TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(raw_element_shape,
&element_shape));
+ return TensorListConcatShapeInference(c, element_shape);
+ });
- auto* handle_data = c->input_handle_shapes_and_types(0);
- if (handle_data != nullptr && handle_data->size() > 1) {
- return errors::InvalidArgument(
- "Trying to read from list with wrong variant data.");
- }
- if (handle_data != nullptr && handle_data->size() == 1) {
- const shape_inference::ShapeAndType& list_shape_type =
- (*handle_data)[0];
- if (list_shape_type.dtype != element_dtype) {
- return errors::InvalidArgument(
- "Trying to read from list with wrong element dtype. List has "
- "type ",
- DataTypeString(list_shape_type.dtype), " but expected type ",
- DataTypeString(element_dtype));
- }
- shape_inference::ShapeHandle merged;
- TF_RETURN_IF_ERROR(
- c->Merge(element_shape, list_shape_type.shape, &merged));
- element_shape = merged;
- }
- if (c->RankKnown(element_shape)) {
- shape_inference::ShapeHandle result;
- TF_RETURN_IF_ERROR(c->Subshape(element_shape, 1, &result));
- TF_RETURN_IF_ERROR(
- c->Concatenate(c->MakeShape({c->UnknownDim()}), result, &result));
- c->set_output(0, result);
- } else {
- c->set_output(0, c->UnknownShape());
- }
- c->set_output(1, c->MakeShape({c->UnknownDim()}));
- return Status::OK();
+REGISTER_OP("TensorListConcatV2")
+ .Input("input_handle: variant")
+ .Input("element_shape: shape_type")
+ .Input("leading_dims: int64")
+ .Output("tensor: element_dtype")
+ .Output("lengths: int64")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle element_shape;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
+ return TensorListConcatShapeInference(c, element_shape);
});
REGISTER_OP("TensorListSplit")
diff --git a/tensorflow/core/ops/mkl_array_ops.cc b/tensorflow/core/ops/mkl_array_ops.cc
new file mode 100644
index 0000000..e7ad3be
--- /dev/null
+++ b/tensorflow/core/ops/mkl_array_ops.cc
@@ -0,0 +1,92 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef INTEL_MKL
+
+// This file contains the registration of MKL-DNN array ops.
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/strided_slice_op.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+using shape_inference::UnchangedShape;
+
+// Adding QuantizedConcatV2 op to be able to replace it by
+// _MklQuantizedConcatV2 in the graph rewrite.
+REGISTER_OP("QuantizedConcatV2")
+ .Input("values: N * T")
+ .Input("axis: Tidx")
+ .Input("input_mins: N * float32")
+ .Input("input_maxes: N * float32")
+ .Output("output: T")
+ .Output("output_min: float")
+ .Output("output_max: float")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ const int n = (c->num_inputs() - 1) / 3;
+ TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
+ ShapeHandle unused;
+ for (int i = n + 1; i < c->num_inputs(); ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
+ }
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("_MklQuantizedConcatV2")
+ .Input("values: N * T")
+ .Input("axis: Tidx")
+ .Input("input_mins: N * float32")
+ .Input("input_maxes: N * float32")
+ .Input("mkl_values: N * uint8")
+ .Input("mkl_axis: uint8")
+ .Input("mkl_input_mins: N * uint8")
+ .Input("mkl_input_maxes: N * uint8")
+ .Output("output: T")
+ .Output("output_min: float")
+ .Output("output_max: float")
+ .Output("mkl_output: uint8")
+ .Output("mkl_output_min: uint8")
+ .Output("mkl_output_max: uint8")
+ .Attr("N: int >= 2")
+ .Attr("T: type")
+ .Attr("Tidx: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ const int n = (c->num_inputs() / 2 - 1) / 3;
+ TF_RETURN_IF_ERROR(shape_inference::QuantizedConcatV2Shape(c, n));
+ ShapeHandle unused;
+ for (int i = n + 1; i < c->num_inputs() / 2; ++i) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
+ }
+ c->set_output(1, c->Scalar());
+ c->set_output(2, c->Scalar());
+ return Status::OK();
+ });
+} // namespace tensorflow
+
+#endif
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 7fd8a36..c0ec654 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -11327,6 +11327,33 @@
is_stateful: true
}
op {
+ name: "ExperimentalRebatchDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_workers"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ExperimentalScanDataset"
input_arg {
name: "input_dataset"
@@ -31696,6 +31723,37 @@
}
}
op {
+ name: "ShardDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_shards"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "index"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ShardedFilename"
input_arg {
name: "basename"
@@ -38347,6 +38405,43 @@
}
}
op {
+ name: "TensorListConcatV2"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ input_arg {
+ name: "leading_dims"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ output_arg {
+ name: "lengths"
+ type: DT_INT64
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListElementShape"
input_arg {
name: "input_handle"
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index e949008..f9e3e1d 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -579,7 +579,10 @@
return []
def tf_additional_device_tracer_deps():
- return []
+ return [
+ "//tensorflow/core/profiler/lib:traceme",
+ "//tensorflow/core/profiler/internal/cpu:host_tracer",
+ ]
def tf_additional_device_tracer_test_flags():
return []
@@ -734,7 +737,6 @@
return ["@nsync//:nsync_cpp"] + if_cuda(
[
"//tensorflow/stream_executor:cuda_platform",
- "//tensorflow/core/platform/default/build_config:cuda",
],
) + if_rocm(
[
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index 6faf5c5..845fe0e 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -7,6 +7,7 @@
exports_files(["LICENSE"])
+load("//tensorflow:tensorflow.bzl", "check_deps")
load("//tensorflow:tensorflow.bzl", "if_cuda")
load("//tensorflow:tensorflow.bzl", "if_rocm")
load("//tensorflow:tensorflow.bzl", "tf_copts")
@@ -283,6 +284,20 @@
],
)
+# Check that libtensorflow_framework.so does not depend on cuda shared libraries.
+check_deps(
+ name = "libtensorflow_cuda_check_deps",
+ disallowed_deps = [
+ ":cuda",
+ "@local_config_cuda//cuda:cublas",
+ "@local_config_cuda//cuda:cuda_driver",
+ "@local_config_cuda//cuda:cudnn",
+ "@local_config_cuda//cuda:curand",
+ "@local_config_cuda//cuda:cusolver",
+ ],
+ deps = ["//tensorflow:libtensorflow_framework.so"],
+)
+
cc_library(
name = "rocm",
data = [],
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index 8351362..ffcb38f 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -31,6 +31,8 @@
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/profiler/internal/cpu/host_tracer.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
namespace {
@@ -299,6 +301,14 @@
class TraceCollectorImpl : public tracing::TraceCollector {
public:
+ class ActivityHandle : public Handle {
+ public:
+ ActivityHandle(string &&name, int level)
+ : trace_me_(std::move(name), level) {}
+
+ private:
+ profiler::TraceMe trace_me_;
+ };
TraceCollectorImpl() { tracing::SetTraceCollector(this); }
~TraceCollectorImpl() override {
@@ -318,14 +328,16 @@
}
~Impl() override { tls_current_annotation.get() = nullptr; }
};
- return std::unique_ptr<Handle>(
- new Impl{ConcatenateNames(name_part1, name_part2)});
+ return absl::make_unique<Impl>(ConcatenateNames(name_part1, name_part2));
}
- virtual std::unique_ptr<Handle> CreateActivityHandle(StringPiece, StringPiece,
- bool) const {
- // We don't do anything with 'Activities' yet.
- return nullptr;
+ virtual std::unique_ptr<Handle> CreateActivityHandle(
+ StringPiece name_part1, StringPiece name_part2, bool is_expensive) const {
+ if (!IsEnabledForActivities(is_expensive)) {
+ return nullptr;
+ }
+ return absl::make_unique<ActivityHandle>(
+ ConcatenateNames(name_part1, name_part2), GetLevel(is_expensive));
}
bool IsEnabledForAnnotations() const override {
@@ -333,8 +345,7 @@
}
bool IsEnabledForActivities(bool is_expensive) const override {
- // We don't do anything with 'Activities' so we are never 'enabled'.
- return false;
+ return profiler::TraceMeRecorder::Active(GetLevel(is_expensive));
}
void Start() {
@@ -349,6 +360,10 @@
}
private:
+ static int GetLevel(bool is_expensive) {
+ return profiler::GetTFTraceMeLevel(is_expensive);
+ }
+
std::atomic<bool> active_trace_session_;
};
@@ -421,6 +436,7 @@
int64 end_walltime_us_ GUARDED_BY(mu_);
uint64_t start_timestamp_ GUARDED_BY(mu_);
uint64_t end_timestamp_ GUARDED_BY(mu_);
+ std::unique_ptr<profiler::cpu::HostTracer> host_tracer_ GUARDED_BY(mu_);
TF_DISALLOW_COPY_AND_ASSIGN(DeviceTracerImpl);
};
@@ -429,6 +445,7 @@
: cupti_manager_(cupti_manager) {
VLOG(1) << "DeviceTracer created.";
cupti_wrapper_.reset(new perftools::gputools::profiler::CuptiWrapper());
+ host_tracer_ = profiler::cpu::HostTracer::Create(2);
enabled_ = false;
}
@@ -493,6 +510,7 @@
CUPTI_CALL(GetTimestamp(&start_timestamp_));
start_walltime_us_ = NowInUsec();
+ host_tracer_->Start().IgnoreError();
enabled_ = true;
return Status::OK();
}
@@ -510,6 +528,7 @@
end_walltime_us_ = NowInUsec();
CUPTI_CALL(GetTimestamp(&end_timestamp_));
enabled_ = false;
+ host_tracer_->Stop().IgnoreError();
return Status::OK();
}
@@ -676,6 +695,8 @@
collector->Save(memcpy_device, ns);
collector->Save(strings::StrCat(stream_device, rec.stream_id), nscopy);
}
+
+ host_tracer_->CollectDataToCollector(collector).IgnoreError();
return Status::OK();
}
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index 4d1826c..65cb848 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -209,8 +209,12 @@
// We lock inside the loop rather than outside so we don't block other
// concurrent readers.
mutex_lock lock(mu_);
+ // Max read length is INT_MAX-2, for hdfsPread function take a parameter
+ // of int32. -2 offset can avoid JVM OutOfMemoryError.
+ size_t read_n =
+ std::min(n, static_cast<size_t>(std::numeric_limits<int>::max() - 2));
tSize r = hdfs_->hdfsPread(fs_, file_, static_cast<tOffset>(offset), dst,
- static_cast<tSize>(n));
+ static_cast<tSize>(read_n));
if (r > 0) {
dst += r;
n -= r;
diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h
index 17a5d5f..7417ec8 100644
--- a/tensorflow/core/platform/logging.h
+++ b/tensorflow/core/platform/logging.h
@@ -19,7 +19,7 @@
#include "tensorflow/core/platform/platform.h" // To pick up PLATFORM_define
#if defined(PLATFORM_GOOGLE) || defined(PLATFORM_GOOGLE_ANDROID) || \
- defined(GOOGLE_LOGGING)
+ defined(GOOGLE_LOGGING) || defined(__EMSCRIPTEN__)
#include "tensorflow/core/platform/google/build_config/logging.h"
#else
#include "tensorflow/core/platform/default/logging.h"
diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h
index 0481b36..671e5dd 100644
--- a/tensorflow/core/platform/platform.h
+++ b/tensorflow/core/platform/platform.h
@@ -40,7 +40,7 @@
#elif defined(_WIN32)
#define PLATFORM_WINDOWS
-#elif defined(__arm__)
+#elif defined(__arm__) || defined(__EMSCRIPTEN__)
#define PLATFORM_POSIX
// Require an outside macro to tell us if we're building for Raspberry Pi or
diff --git a/tensorflow/core/profiler/internal/BUILD b/tensorflow/core/profiler/internal/BUILD
index 8dcfde9..da3039a 100644
--- a/tensorflow/core/profiler/internal/BUILD
+++ b/tensorflow/core/profiler/internal/BUILD
@@ -6,6 +6,8 @@
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
cc_library(
name = "tfprof_stats",
@@ -365,3 +367,43 @@
"//tensorflow/core:regexp_internal",
],
)
+
+tf_cuda_library(
+ name = "traceme_recorder",
+ srcs = ["traceme_recorder.cc"],
+ hdrs = ["traceme_recorder.h"],
+ visibility = [
+ "//learning/brain/runtime:__pkg__", # xprof_bridge
+ "//perftools/accelerators/xprof/xprofilez:__pkg__", # alias xprof::TraceMeRecorder
+ "//tensorflow/core/profiler/internal/cpu:__pkg__", # host_tracer
+ "//tensorflow/core/profiler/lib:__pkg__", # traceme
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/stream_executor/lib",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/container:flat_hash_map",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "traceme_recorder_test",
+ srcs = ["traceme_recorder_test.cc"],
+ deps = [
+ ":traceme_recorder",
+ "//tensorflow/core:lib",
+ "@com_google_absl//absl/synchronization",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cuda_library(
+ name = "profiler_interface",
+ hdrs = [
+ "profiler_interface.h",
+ ],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
diff --git a/tensorflow/core/profiler/internal/cpu/BUILD b/tensorflow/core/profiler/internal/cpu/BUILD
new file mode 100644
index 0000000..b94453c
--- /dev/null
+++ b/tensorflow/core/profiler/internal/cpu/BUILD
@@ -0,0 +1,44 @@
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cuda_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
+
+tf_cuda_library(
+ name = "host_tracer",
+ srcs = [
+ "host_tracer.cc",
+ ],
+ hdrs = [
+ "host_tracer.h",
+ ],
+ deps = [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/profiler/internal:profiler_interface",
+ "//tensorflow/core/profiler/internal:traceme_recorder",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "host_tracer_test",
+ srcs = ["host_tracer_test.cc"],
+ deps = [
+ ":host_tracer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core/profiler/lib:traceme",
+ "@com_google_absl//absl/types:optional",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.cc b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
new file mode 100644
index 0000000..3fb2966
--- /dev/null
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer.cc
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/cpu/host_tracer.h"
+
+#include <utility>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/platform/env_time.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace cpu {
+
+/* static */ std::unique_ptr<HostTracer> HostTracer::Create(
+ int host_trace_level) {
+ return absl::WrapUnique(new HostTracer(host_trace_level));
+}
+HostTracer::HostTracer(int host_trace_level)
+ : host_trace_level_(host_trace_level) {}
+
+HostTracer::~HostTracer() { Stop().IgnoreError(); }
+
+Status HostTracer::Start() {
+ if (recording_) {
+ return Status(error::INTERNAL, "TraceMeRecorder already started");
+ }
+ recording_ = TraceMeRecorder::Start(host_trace_level_);
+ if (!recording_) {
+ return Status(error::INTERNAL, "Failed to start TraceMeRecorder");
+ }
+ return Status::OK();
+}
+
+Status HostTracer::Stop() {
+ if (!recording_) {
+ return Status(error::INTERNAL, "TraceMeRecorder not started");
+ }
+ events_ = TraceMeRecorder::Stop();
+ recording_ = false;
+ return Status::OK();
+}
+
+constexpr char kUserMetadataMarker = '#';
+
+Status HostTracer::CollectData(RunMetadata* run_metadata) {
+ auto step_stats_collector =
+ absl::make_unique<StepStatsCollector>(run_metadata->mutable_step_stats());
+ return CollectDataToCollector(step_stats_collector.get());
+}
+
+Status HostTracer::CollectDataToCollector(
+ StepStatsCollector* step_stats_collector) {
+ if (events_.empty() && recording_) {
+ events_ = TraceMeRecorder::Collect();
+ }
+ // Pair up start and end events, and add complete events to trace_entries.
+ absl::flat_hash_map<uint64, uint64> end_times;
+ for (const auto& thread : events_) {
+ for (const auto& event : thread.events) {
+ if (event.end_time && !event.start_time) {
+ end_times.emplace(event.activity_id, event.end_time);
+ }
+ }
+ }
+
+ const string cpu_name = "/host:CPU";
+ for (auto& thread : events_) {
+ step_stats_collector->SaveThreadName(cpu_name, thread.thread.tid,
+ thread.thread.name);
+ for (auto& event : thread.events) {
+ if (!event.end_time) {
+ auto it = end_times.find(event.activity_id);
+ if (it != end_times.end()) event.end_time = it->second;
+ }
+ if (event.start_time && event.end_time) {
+ NodeExecStats* ns = new NodeExecStats;
+ if (event.name.back() != kUserMetadataMarker) {
+ ns->set_node_name(std::move(event.name));
+ } else {
+ // Expect the format will be "<name>#<metadata>#"
+ std::vector<absl::string_view> parts =
+ absl::StrSplit(event.name, kUserMetadataMarker);
+ if (parts.size() >= 2) {
+ ns->set_node_name(string(parts[0]));
+ ns->set_timeline_label(string(parts[1]));
+ } else {
+ ns->set_node_name(std::move(event.name));
+ }
+ }
+ ns->set_all_start_micros(event.start_time / EnvTime::kMicrosToNanos);
+ ns->set_all_end_rel_micros((event.end_time - event.start_time) /
+ EnvTime::kMicrosToNanos);
+ ns->set_thread_id(thread.thread.tid);
+ // TODO(fishx): Add thread name to RunMetadata
+ step_stats_collector->Save(cpu_name, ns);
+ }
+ }
+ }
+ events_.clear();
+ step_stats_collector->Finalize();
+ return Status::OK();
+}
+
+} // namespace cpu
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer.h b/tensorflow/core/profiler/internal/cpu/host_tracer.h
new file mode 100644
index 0000000..c6340c2
--- /dev/null
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer.h
@@ -0,0 +1,67 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_
+
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
+#include "tensorflow/core/profiler/internal/traceme_recorder.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace cpu {
+
+// Controls TraceMeRecorder and converts TraceMeRecorder::Events into
+// RunMetadata messages.
+//
+// Thread-safety: This class is go/thread-compatible.
+class HostTracer : public ProfilerInterface {
+ public:
+ static std::unique_ptr<HostTracer> Create(int host_trace_level);
+
+ ~HostTracer();
+
+ // Starts recording TraceMes.
+ Status Start() override;
+
+ // Stops recording TraceMes.
+ Status Stop() override;
+
+ // Populates user traces and thread names in response.
+ // The user traces and thread names are in no particular order.
+ Status CollectData(RunMetadata* run_metadata) override;
+
+ Status CollectDataToCollector(StepStatsCollector* step_stats_collector);
+
+ private:
+ explicit HostTracer(int host_trace_level);
+
+ // Level of host tracing.
+ const int host_trace_level_;
+
+ // True if currently recording.
+ bool recording_ = false;
+
+ // Container of all traced events.
+ TraceMeRecorder::Events events_;
+};
+
+} // namespace cpu
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_CPU_HOST_TRACER_H_
diff --git a/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
new file mode 100644
index 0000000..51f9c6a
--- /dev/null
+++ b/tensorflow/core/profiler/internal/cpu/host_tracer_test.cc
@@ -0,0 +1,133 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/cpu/host_tracer.h"
+
+#include <string>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/types/optional.h"
+#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/profiler/lib/traceme.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace cpu {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Pair;
+using ::testing::UnorderedElementsAre;
+
+NodeExecStats MakeNodeStats(const string& name, uint64 thread_id,
+ const string& label = "") {
+ NodeExecStats ns;
+ ns.set_node_name(name);
+ ns.set_thread_id(thread_id);
+ if (!label.empty()) {
+ ns.set_timeline_label(label);
+ }
+ return ns;
+}
+
+class NodeStatsMatcher {
+ public:
+ explicit NodeStatsMatcher(const NodeExecStats& expected)
+ : expected_(expected) {}
+
+ bool MatchAndExplain(const NodeExecStats& p,
+ ::testing::MatchResultListener* /* listener */) const {
+ return p.node_name() == expected_.node_name() &&
+ p.thread_id() == expected_.thread_id() &&
+ p.timeline_label() == expected_.timeline_label();
+ }
+
+ void DescribeTo(::std::ostream* os) const { *os << expected_.DebugString(); }
+ void DescribeNegationTo(::std::ostream* os) const {
+ *os << "not equal to expected message: " << expected_.DebugString();
+ }
+
+ private:
+ const NodeExecStats expected_;
+};
+
+inline ::testing::PolymorphicMatcher<NodeStatsMatcher> EqualsNodeStats(
+ const NodeExecStats& expected) {
+ return ::testing::MakePolymorphicMatcher(NodeStatsMatcher(expected));
+}
+
+TEST(HostTracerTest, CollectsTraceMeEvents) {
+ uint32 thread_id = Env::Default()->GetCurrentThreadId();
+
+ auto tracer = HostTracer::Create(/*host_trace_level=*/1);
+
+ TF_ASSERT_OK(tracer->Start());
+ { TraceMe traceme("hello"); }
+ { TraceMe traceme("world"); }
+ { TraceMe traceme("contains#inside"); }
+ { TraceMe traceme("good#key1=value1#"); }
+ { TraceMe traceme("morning#key1=value1,key2=value2#"); }
+ { TraceMe traceme("incomplete#key1=value1,key2#"); }
+ TF_ASSERT_OK(tracer->Stop());
+
+ RunMetadata run_metadata;
+ TF_ASSERT_OK(tracer->CollectData(&run_metadata));
+
+ EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 1);
+ EXPECT_EQ(run_metadata.step_stats().dev_stats(0).node_stats_size(), 6);
+ EXPECT_THAT(
+ run_metadata.step_stats().dev_stats(0).node_stats(),
+ UnorderedElementsAre(
+ EqualsNodeStats(MakeNodeStats("hello", thread_id)),
+ EqualsNodeStats(MakeNodeStats("world", thread_id)),
+ EqualsNodeStats(MakeNodeStats("contains#inside", thread_id)),
+ EqualsNodeStats(MakeNodeStats("good", thread_id, "key1=value1")),
+ EqualsNodeStats(
+ MakeNodeStats("morning", thread_id, "key1=value1,key2=value2")),
+ EqualsNodeStats(
+ MakeNodeStats("incomplete", thread_id, "key1=value1,key2"))));
+}
+
+void ValidateResult(const RunMetadata& run_metadata, const string& trace_name) {
+ uint32 thread_id = Env::Default()->GetCurrentThreadId();
+
+ EXPECT_THAT(
+ run_metadata.step_stats().dev_stats(0).node_stats(),
+ ElementsAre(EqualsNodeStats(MakeNodeStats(trace_name, thread_id))));
+}
+
+TEST(HostTracerTest, CollectsTraceMeEventsBetweenTracing) {
+ auto tracer = HostTracer::Create(/*host_trace_level=*/1);
+ RunMetadata run_metadata;
+ RunMetadata run_metadata2;
+
+ TF_ASSERT_OK(tracer->Start());
+ { TraceMe traceme("hello"); }
+ TF_ASSERT_OK(tracer->CollectData(&run_metadata));
+ { TraceMe traceme("world"); }
+ TF_ASSERT_OK(tracer->CollectData(&run_metadata2));
+ TF_ASSERT_OK(tracer->Stop());
+
+ ValidateResult(run_metadata, "hello");
+ ValidateResult(run_metadata2, "world");
+}
+
+} // namespace
+} // namespace cpu
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD
new file mode 100644
index 0000000..35f90e9
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/BUILD
@@ -0,0 +1,25 @@
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cuda_library",
+)
+
+tf_cuda_library(
+ name = "tracer",
+ srcs = [
+ "tracer.cc",
+ ],
+ hdrs = [
+ "tracer.h",
+ ],
+ deps = [
+ "//tensorflow/core:core_cpu_lib",
+ "//tensorflow/core:device_tracer",
+ "//tensorflow/core/profiler/internal:profiler_interface",
+ ],
+)
diff --git a/tensorflow/core/profiler/internal/gpu/tracer.cc b/tensorflow/core/profiler/internal/gpu/tracer.cc
new file mode 100644
index 0000000..f1cb541
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/tracer.cc
@@ -0,0 +1,59 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/gpu/tracer.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace gpu {
+
+/* static */ std::unique_ptr<ProfilerInterface> Tracer::Create() {
+ return absl::WrapUnique(new Tracer());
+}
+
+Status Tracer::Start() {
+ device_tracer_ = CreateDeviceTracer();
+ if (!device_tracer_) {
+ return Status(tensorflow::error::Code::FAILED_PRECONDITION,
+ "Failed to create device tracer.");
+ }
+ return device_tracer_->Start();
+}
+
+Status Tracer::Stop() {
+ if (!device_tracer_) {
+ return Status(tensorflow::error::Code::FAILED_PRECONDITION,
+ "No running device tracer.");
+ }
+ return device_tracer_->Stop();
+}
+
+Status Tracer::CollectData(RunMetadata* run_metadata) {
+ if (!device_tracer_) {
+ return Status(tensorflow::error::Code::FAILED_PRECONDITION,
+ "No running device tracer.");
+ }
+ auto step_stats_collector =
+ absl::make_unique<StepStatsCollector>(run_metadata->mutable_step_stats());
+ Status s = device_tracer_->Collect(step_stats_collector.get());
+ step_stats_collector->Finalize();
+ return s;
+}
+
+Tracer::Tracer() {}
+
+} // namespace gpu
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/gpu/tracer.h b/tensorflow/core/profiler/internal/gpu/tracer.h
new file mode 100644
index 0000000..d776543
--- /dev/null
+++ b/tensorflow/core/profiler/internal/gpu/tracer.h
@@ -0,0 +1,48 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_
+
+#include "tensorflow/core/platform/device_tracer.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace gpu {
+
+class Tracer : public ProfilerInterface {
+ public:
+ static std::unique_ptr<ProfilerInterface> Create();
+
+ Status Start() override;
+
+ Status Stop() override;
+
+ Status CollectData(RunMetadata* run_metadata) override;
+
+ private:
+ Tracer();
+
+ // Trace is neither copyable nor movable.
+ Tracer(const Tracer&) = delete;
+ Tracer& operator=(const Tracer&) = delete;
+
+ std::unique_ptr<DeviceTracer> device_tracer_;
+};
+
+} // namespace gpu
+} // namespace profiler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_TRACER_H_
diff --git a/tensorflow/core/profiler/internal/profiler_interface.h b/tensorflow/core/profiler/internal/profiler_interface.h
new file mode 100644
index 0000000..144c4bb
--- /dev/null
+++ b/tensorflow/core/profiler/internal/profiler_interface.h
@@ -0,0 +1,49 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Interface for tensorflow profiler plugins.
+//
+// ProfileSession calls each of these methods at most once per instance, and
+// implementations can rely on that guarantee for simplicity.
+//
+// Thread-safety: Implementations are only required to be go/thread-compatible.
+// ProfileSession is go/thread-safe and synchronizes access to ProfilerInterface
+// instances.
+class ProfilerInterface {
+ public:
+ virtual ~ProfilerInterface() = default;
+
+ // Starts profiling.
+ virtual Status Start() = 0;
+
+ // Stops profiling.
+ virtual Status Stop() = 0;
+
+ // Moves collected profile data into run_metadata.
+ virtual Status CollectData(RunMetadata* run_metadata) = 0;
+};
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_PROFILER_INTERFACE_H_
diff --git a/tensorflow/core/profiler/internal/runtime/BUILD b/tensorflow/core/profiler/internal/runtime/BUILD
new file mode 100644
index 0000000..2e383f1
--- /dev/null
+++ b/tensorflow/core/profiler/internal/runtime/BUILD
@@ -0,0 +1,24 @@
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_cuda_library",
+)
+
+tf_cuda_library(
+ name = "eager_profiler",
+ srcs = [
+ "eager_profiler.cc",
+ ],
+ hdrs = [
+ "eager_profiler.h",
+ ],
+ deps = [
+ "//tensorflow/core/common_runtime/eager:context",
+ "//tensorflow/core/profiler/internal:profiler_interface",
+ ],
+)
diff --git a/tensorflow/core/profiler/internal/runtime/eager_profiler.cc b/tensorflow/core/profiler/internal/runtime/eager_profiler.cc
new file mode 100644
index 0000000..aad692b
--- /dev/null
+++ b/tensorflow/core/profiler/internal/runtime/eager_profiler.cc
@@ -0,0 +1,61 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/runtime/eager_profiler.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace runtime {
+
+TraceCollector::TraceCollector(EagerContext* const eager_context)
+ : context_(eager_context) {}
+
+void TraceCollector::BeforeClearRunMetadata() {
+ run_metadata_.MergeFrom(*context_->RunMetadataProto());
+}
+
+Status TraceCollector::CollectData(RunMetadata* run_metadata) {
+ run_metadata->MergeFrom(run_metadata_);
+ return Status::OK();
+}
+
+/* static */ std::unique_ptr<ProfilerInterface> EagerProfiler::Create(
+ EagerContext* const eager_context) {
+ return absl::WrapUnique(new EagerProfiler(eager_context));
+}
+
+Status EagerProfiler::Start() {
+ if (context_ == nullptr) {
+ return Status(tensorflow::error::Code::FAILED_PRECONDITION,
+ "No eager context attached.");
+ }
+ return context_->RegisterRunMetadataListener(&collector_);
+}
+
+Status EagerProfiler::Stop() {
+ collector_.BeforeClearRunMetadata();
+ context_->ClearRunMetadataListener();
+ return Status::OK();
+}
+
+Status EagerProfiler::CollectData(RunMetadata* run_metadata) {
+ return collector_.CollectData(run_metadata);
+}
+
+EagerProfiler::EagerProfiler(EagerContext* const eager_context)
+ : context_(eager_context), collector_(eager_context) {}
+
+} // namespace runtime
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/internal/runtime/eager_profiler.h b/tensorflow/core/profiler/internal/runtime/eager_profiler.h
new file mode 100644
index 0000000..7135355
--- /dev/null
+++ b/tensorflow/core/profiler/internal/runtime/eager_profiler.h
@@ -0,0 +1,64 @@
+/* Copyright 2016 The TensorFlow Authors All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_
+
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace runtime {
+
+class TraceCollector : public RunMetadataListener {
+ public:
+ TraceCollector(EagerContext* const eager_context);
+
+ void BeforeClearRunMetadata() override;
+
+ Status CollectData(RunMetadata* run_metadata);
+
+ private:
+ RunMetadata run_metadata_;
+ EagerContext* const context_;
+};
+
+class EagerProfiler : public ProfilerInterface {
+ public:
+ static std::unique_ptr<ProfilerInterface> Create(
+ EagerContext* const eager_context);
+
+ Status Start() override;
+
+ Status Stop() override;
+
+ Status CollectData(RunMetadata* run_metadata) override;
+
+ private:
+ EagerProfiler(EagerContext* const eager_context);
+
+ // Trace is neither copyable nor movable.
+ EagerProfiler(const EagerProfiler&) = delete;
+ EagerProfiler& operator=(const EagerProfiler&) = delete;
+
+ EagerContext* const context_;
+ TraceCollector collector_;
+};
+
+} // namespace runtime
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_RUNTIME_EAGER_PROFILER_H_
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.cc b/tensorflow/core/profiler/internal/traceme_recorder.cc
new file mode 100644
index 0000000..0369e0b
--- /dev/null
+++ b/tensorflow/core/profiler/internal/traceme_recorder.cc
@@ -0,0 +1,248 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/traceme_recorder.h"
+
+// To avoid unneccesary synchronization between threads, each thread has a
+// ThreadLocalRecorder that independently records its events.
+//
+// Events are stored in an EventQueue implemented as a linked-list of blocks,
+// with start and end pointers:
+// [ events........ | next-]--> [ events......... | next ]
+// ^start_block ^start ^end_block ^end
+//
+// Record() writes at end, and then advances it, allocating a block if needed.
+// Clear() takes ownership of events in the range [start, end).
+// The end pointer is atomic so these can be concurrent.
+//
+// If a thread dies, the ThreadLocalRecorder's destructor hands its data off to
+// the orphaned_events list.
+
+#include <string>
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/stream_executor/lib/initialize.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Default value for g_trace_level when tracing is disabled
+constexpr static int kTracingDisabled = -1;
+
+namespace internal {
+std::atomic<int> g_trace_level = ATOMIC_VAR_INIT(kTracingDisabled);
+} // namespace internal
+
+namespace {
+
+class ThreadLocalRecorder;
+
+struct Data {
+ // Lock for only rare events - start/stop, thread death.
+ mutex global_lock;
+ // Map of the static container instances (thread_local storage) for each
+ // thread, that store the trace events.
+ absl::flat_hash_map<uint64, ThreadLocalRecorder*> threads
+ GUARDED_BY(global_lock);
+ // Events traced from threads that died during tracing.
+ TraceMeRecorder::Events orphaned_events GUARDED_BY(global_lock);
+}* g_data = nullptr;
+
+// A single-producer single-consumer queue of Events.
+// Only the owner thread can write events, writing is lock-free.
+// Consume is also lock-free in this class.
+//
+// Internally, we have a linked list of blocks containing numbered slots.
+// start is the first occupied slot, end is the first unoccupied slot.
+class EventQueue {
+ public:
+ EventQueue()
+ : start_block_(new Block{0, nullptr}), end_block_(start_block_) {}
+
+ // REQUIRES: Consume() was called since the last Push().
+ // Memory should be deallocated and trace events destroyed on destruction.
+ // This doesn't require global lock as this discards all the stored trace
+ // events and we assume of destruction of this class only after the last
+ // Push() has been called.
+ ~EventQueue() {
+ DCHECK_EQ(start_, end_.load()) << "EventQueue destroyed without Consume()";
+ delete end_block_;
+ }
+
+ // Add a new event to the back of the queue. Fast and lock-free.
+ void Push(TraceMeRecorder::Event&& event) {
+ uint64 end = end_.load(std::memory_order_relaxed);
+ new (&end_block_->events[end++ - end_block_->start].event)
+ TraceMeRecorder::Event(std::move(event));
+ if (ABSL_PREDICT_FALSE(end - end_block_->start == Block::kLength)) {
+ auto* new_block = new Block{end, nullptr};
+ end_block_->next = new_block;
+ end_block_ = new_block;
+ }
+ end_.store(end, std::memory_order_release); // Write index after contents.
+ }
+
+ // Retrieve and remove all events in the queue.
+ std::vector<TraceMeRecorder::Event> Consume() {
+ // Read index before contents.
+ uint64 end = end_.load(std::memory_order_acquire);
+ std::vector<TraceMeRecorder::Event> result;
+ result.reserve(end - start_);
+ while (start_ != end) {
+ Shift(&result);
+ }
+ return result;
+ }
+
+ private:
+ // Shift one event off the front of the queue into *out.
+ void Shift(std::vector<TraceMeRecorder::Event>* out) {
+ // Move the next event into the output.
+ auto& event = start_block_->events[start_++ - start_block_->start].event;
+ out->push_back(std::move(event));
+ event.~Event(); // Events must be individually destroyed.
+ // If we reach the end of a block, we own it and should delete it.
+ // The next block is present: end always points to something.
+ if (start_ - start_block_->start == Block::kLength) {
+ auto* next_block = start_block_->next;
+ delete start_block_;
+ start_block_ = next_block;
+ }
+ }
+
+ // The number of slots in a block. Chosen so that the block fits in 64k.
+ struct Block {
+ static constexpr size_t kLength =
+ ((1 << 16) - (sizeof(uint64) + sizeof(std::atomic<Block*>))) /
+ sizeof(TraceMeRecorder::Event);
+
+ const uint64 start; // The number of the first slot.
+ Block* next;
+ // Defer construction of Event until the data is available.
+ // Must also destroy manually, as the block may not fill entirely.
+ union MaybeEvent {
+ MaybeEvent() {}
+ ~MaybeEvent() {}
+ TraceMeRecorder::Event event;
+ } events[kLength];
+ };
+
+ // Head of list for reading. Only accessed by consumer thread.
+ Block* start_block_;
+ uint64 start_ = 0;
+ // Tail of list for writing. Accessed by producer thread.
+ Block* end_block_;
+ std::atomic<uint64> end_ = {0}; // Atomic: also read by consumer thread.
+};
+
+class ThreadLocalRecorder {
+ public:
+ // The recorder is created the first time Record() is called on a thread.
+ ThreadLocalRecorder() {
+ auto* env = Env::Default();
+ info_.tid = env->GetCurrentThreadId();
+ env->GetCurrentThreadName(&info_.name);
+ mutex_lock lock(g_data->global_lock);
+ g_data->threads.emplace(info_.tid, this);
+ }
+
+ // The destructor is called when the thread shuts down early.
+ // We unregister this thread, and move its events to orphaned_events.
+ ~ThreadLocalRecorder() {
+ mutex_lock lock(g_data->global_lock);
+ g_data->threads.erase(info_.tid);
+ g_data->orphaned_events.push_back(Clear());
+ }
+
+ // This is the performance-critical part!
+ void Record(TraceMeRecorder::Event&& event) { queue_.Push(std::move(event)); }
+
+ TraceMeRecorder::ThreadEvents Clear()
+ EXCLUSIVE_LOCKS_REQUIRED(g_data->global_lock) {
+ return {info_, queue_.Consume()};
+ }
+
+ private:
+ TraceMeRecorder::ThreadInfo info_;
+ EventQueue queue_;
+};
+
+// Gather events from all active threads, and clear their buffers. The global
+// lock is held, so no threads can be added/removed for the duration while we
+// consume the collected trace entries. This will block any new thread and also
+// the starting and stopping of TraceMeRecorder, hence, this is performance
+// critical and should be kept fast.
+TraceMeRecorder::Events Clear() EXCLUSIVE_LOCKS_REQUIRED(g_data->global_lock) {
+ TraceMeRecorder::Events result;
+ std::swap(g_data->orphaned_events, result);
+ for (const auto& entry : g_data->threads) {
+ auto* recorder = entry.second;
+ result.push_back(recorder->Clear());
+ }
+ return result;
+}
+
+} // namespace
+
+bool TraceMeRecorder::Start(int level) {
+ level = std::max(0, level);
+ mutex_lock lock(g_data->global_lock);
+ int expected = kTracingDisabled;
+ if (!internal::g_trace_level.compare_exchange_strong(
+ expected, level, std::memory_order_acq_rel)) {
+ return false;
+ }
+ // We may have old events in buffers because Record() raced with Stop().
+ Clear();
+ return true;
+}
+
+
+void TraceMeRecorder::Record(Event event) {
+ static thread_local ThreadLocalRecorder thread_local_recorder;
+ thread_local_recorder.Record(std::move(event));
+}
+
+// Only one thread is expected to call Stop() as first instance of XprofSession
+// prevents another XprofSession from doing any profiling.
+TraceMeRecorder::Events TraceMeRecorder::Stop() {
+ mutex_lock lock(g_data->global_lock);
+ if (internal::g_trace_level.exchange(
+ kTracingDisabled, std::memory_order_acq_rel) == kTracingDisabled) {
+ return {};
+ }
+ return Clear();
+}
+
+TraceMeRecorder::Events TraceMeRecorder::Collect() {
+ mutex_lock lock(g_data->global_lock);
+ if (internal::g_trace_level.load(std::memory_order_acquire) ==
+ kTracingDisabled) {
+ return {};
+ }
+ return Clear();
+}
+
+} // namespace profiler
+} // namespace tensorflow
+
+REGISTER_MODULE_INITIALIZER(traceme_recorder, {
+ tensorflow::profiler::g_data = new tensorflow::profiler::Data();
+
+ // Workaround for b/35097229, the first block-scoped thread_local can
+ // trigger false positives in the heap checker. Currently triggered by
+ // //perftools/accelerators/xprof/xprofilez/integration_tests:xla_hlo_trace_test
+ static thread_local tensorflow::string fix_deadlock ABSL_ATTRIBUTE_UNUSED;
+});
diff --git a/tensorflow/core/profiler/internal/traceme_recorder.h b/tensorflow/core/profiler/internal/traceme_recorder.h
new file mode 100644
index 0000000..1e66b1e
--- /dev/null
+++ b/tensorflow/core/profiler/internal/traceme_recorder.h
@@ -0,0 +1,95 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
+
+#include <atomic>
+#include <vector>
+#include "absl/base/optimization.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace profiler {
+
+namespace internal {
+extern std::atomic<int> g_trace_level;
+} // namespace internal
+
+// TraceMeRecorder is a singleton repository of TraceMe events.
+// It can be safely and cheaply appended to by multiple threads.
+//
+// Start() and Stop() must be called in pairs, Stop() returns the events added
+// since the previous Start().
+//
+// This is the backend for TraceMe instrumentation.
+// The profiler starts the recorder, the TraceMe constructor records begin
+// events, and the destructor records end events.
+// The profiler then stops the recorder and finds start/end pairs. (Unpaired
+// start/end events are discarded at that point).
+class TraceMeRecorder {
+ public:
+ // An Event is either the start of a TraceMe, the end of a TraceMe, or both.
+ // Times are in ns since the Unix epoch.
+ struct Event {
+ uint64 activity_id;
+ string name;
+ uint64 start_time; // 0 = missing
+ uint64 end_time; // 0 = missing
+ };
+ struct ThreadInfo {
+ int64 tid;
+ string name;
+ };
+ struct ThreadEvents {
+ const ThreadInfo thread;
+ std::vector<Event> events;
+ };
+ using Events = std::vector<ThreadEvents>;
+
+ // Starts recording of TraceMe().
+ // Only traces <= level will be recorded.
+ // Level must be >= 0.
+ // If level is 0, no traces will be recorded.
+ static bool Start(int level);
+
+ // Stops recording and returns events recorded since Start().
+ static Events Stop();
+
+ // Returns events recorded till now without stopping the recording. Empty
+ // container is returned if the recorder was already stopped.
+ static Events Collect();
+
+ // Returns whether we're currently recording. Racy, but cheap!
+ static inline bool Active(int level = 1) {
+ return ABSL_PREDICT_FALSE(
+ internal::g_trace_level.load(std::memory_order_acquire) >= level);
+ }
+
+ static void Record(Event);
+
+ private:
+ // No copy and assignment
+ TraceMeRecorder(const TraceMeRecorder&) = delete;
+ TraceMeRecorder& operator=(const TraceMeRecorder&) = delete;
+
+ // Implementation of g_trace_level must be lock-free for faster execution
+ // of the TraceMe() public API. This can be commented (if compilation is
+ // failing) but execution might be slow (even when host tracing is disabled).
+ static_assert(ATOMIC_INT_LOCK_FREE == 2, "Assumed atomic<int> was lock free");
+};
+
+} // namespace profiler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TRACEME_RECORDER_H_
diff --git a/tensorflow/core/profiler/internal/traceme_recorder_test.cc b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
new file mode 100644
index 0000000..ec588af
--- /dev/null
+++ b/tensorflow/core/profiler/internal/traceme_recorder_test.cc
@@ -0,0 +1,211 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/profiler/internal/traceme_recorder.h"
+
+#include <atomic>
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/synchronization/notification.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace profiler {
+namespace {
+
+MATCHER_P(Named, name, "") { return arg.name == name; }
+
+constexpr static uint64 kNanosInSec = 1000000000;
+
+TEST(RecorderTest, SingleThreaded) {
+ uint64 start_time = Env::Default()->NowNanos();
+ uint64 end_time = start_time + kNanosInSec;
+
+ TraceMeRecorder::Record({1, "before", start_time, end_time});
+ TraceMeRecorder::Start(/*level=*/1);
+ TraceMeRecorder::Record({2, "during1", start_time, end_time});
+ TraceMeRecorder::Record({3, "during2", start_time, end_time});
+ auto results = TraceMeRecorder::Stop();
+ TraceMeRecorder::Record({4, "after", start_time, end_time});
+
+ ASSERT_EQ(results.size(), 1);
+ EXPECT_THAT(results[0].events,
+ ::testing::ElementsAre(Named("during1"), Named("during2")));
+}
+
+TEST(RecorderTest, CollectionBeforeStop) {
+ uint64 start_time = Env::Default()->NowNanos();
+ uint64 end_time = start_time + kNanosInSec;
+
+ TraceMeRecorder::Record({1, "ignored", start_time, end_time});
+ TraceMeRecorder::Start(/*level=*/1);
+ TraceMeRecorder::Record({2, "during1", start_time, end_time});
+ TraceMeRecorder::Record({3, "during2", start_time, end_time});
+ auto collected_results = TraceMeRecorder::Collect();
+ TraceMeRecorder::Record({4, "after_collect", start_time, end_time});
+ auto stopped_results = TraceMeRecorder::Stop();
+ TraceMeRecorder::Record({5, "after_stop", start_time, end_time});
+ auto results_after_stop = TraceMeRecorder::Collect();
+
+ ASSERT_EQ(collected_results.size(), 1);
+ EXPECT_THAT(collected_results[0].events,
+ ::testing::ElementsAre(Named("during1"), Named("during2")));
+
+ ASSERT_EQ(stopped_results.size(), 1);
+ EXPECT_THAT(stopped_results[0].events,
+ ::testing::ElementsAre(Named("after_collect")));
+
+ ASSERT_EQ(results_after_stop.size(), 0);
+}
+
+void SpinNanos(int nanos) {
+ uint64 deadline = Env::Default()->NowNanos() + nanos;
+ while (Env::Default()->NowNanos() < deadline) {
+ }
+}
+
+// Checks the functional behavior of the recorder, when used from several
+// unsynchronized threads.
+//
+// Each thread records a stream of events.
+// Thread 0: activity=0, activity=1, activity=2, ...
+// Thread 1: activity=0, activity=1, activity=2, ...
+// ...
+//
+// We turn the recorder on and off repeatedly in sessions, expecting to see:
+// - data from every thread (eventually - maybe not every session)
+// - unbroken sessions: a consecutive sequence of IDs from each thread
+// - gaps between sessions: a thread's IDs should be non-consecutive overall
+TEST(RecorderTest, Multithreaded) {
+ constexpr static int kNumThreads = 4;
+
+ // Start several threads writing events.
+ absl::Notification start;
+ absl::Notification stop;
+ thread::ThreadPool pool(Env::Default(), "testpool", kNumThreads);
+ std::atomic<int> thread_count = {0};
+ for (int i = 0; i < kNumThreads; i++) {
+ pool.Schedule([&start, &stop, &thread_count, i] {
+ uint64 j = 0;
+ bool was_active = false;
+ auto record_event = [&j, i]() {
+ uint64 start_time = Env::Default()->NowNanos();
+ uint64 end_time = start_time + kNanosInSec;
+ TraceMeRecorder::Record({/*activity_id=*/j++,
+ /*name=*/strings::StrCat(i), start_time,
+ end_time});
+ };
+ thread_count.fetch_add(1, std::memory_order_relaxed);
+ start.WaitForNotification();
+ while (!stop.HasBeenNotified()) {
+ // Mimicking production usage, we guard with a racy check.
+ // In principle this isn't needed, but a feedback loop can form:
+ // 1) many events accumulate while the recorder is off
+ // 2) clearing/analyzing these events is slow
+ // 3) while clearing, more events are accumulating, causing 1
+ if (TraceMeRecorder::Active()) {
+ record_event();
+ was_active = true;
+ }
+ // Record some events after the recorder is no longer active to simulate
+ // point 1 and 3.
+ if (was_active && !TraceMeRecorder::Active()) {
+ record_event();
+ record_event();
+ was_active = false;
+ }
+ // This snowballs into OOM in some configurations, causing flakiness.
+ // Keep this big enough to prevent OOM and small enough such that
+ // each thread records at least one event.
+ SpinNanos(10);
+ }
+ });
+ }
+
+ // For each thread, keep track of which events we've seen.
+ struct {
+ bool split_session = false;
+ bool overlapping_sessions = false;
+ std::set<uint64> events;
+ } thread_state[kNumThreads];
+ // We expect each thread to eventually have multiple events, not all in a
+ // contiguous range.
+ auto done = [&thread_state] {
+ for (const auto& t : thread_state) {
+ if (t.events.size() < 2) return false;
+ }
+ return true;
+ };
+
+ // Wait while all the threads are spun up.
+ while (thread_count.load(std::memory_order_relaxed) < kNumThreads) {
+ LOG(INFO) << "Waiting for all threads to spin up...";
+ Env::Default()->SleepForMicroseconds(1 * EnvTime::kMillisToMicros);
+ }
+
+ // We will probably be done after two iterations (with each thread getting
+ // some events each iteration). No guarantees as all the threads might not get
+ // scheduled in a session, so try for a while.
+ start.Notify();
+ constexpr static int kMaxIters = 100;
+ for (int iters = 0; iters < kMaxIters && !done(); ++iters) {
+ LOG(INFO) << "Looping until convergence, iteration: " << iters;
+ TraceMeRecorder::Start(/*level=*/1);
+ Env::Default()->SleepForMicroseconds(100 * EnvTime::kMillisToMicros);
+ auto results = TraceMeRecorder::Stop();
+ for (const auto& thread : results) {
+ if (thread.events.empty()) continue;
+ std::istringstream ss(thread.events.front().name);
+ int thread_index = 0;
+ ss >> thread_index;
+ auto& state = thread_state[thread_index];
+
+ std::set<uint64> session_events;
+ uint64 current = 0;
+ for (const auto& event : thread.events) {
+ session_events.emplace(event.activity_id);
+ // Session events should be contiguous.
+ if (current != 0 && event.activity_id != current + 1) {
+ state.split_session = true;
+ }
+ current = event.activity_id;
+ }
+
+ for (const auto& event : session_events) {
+ auto result = state.events.emplace(event);
+ if (!result.second) {
+ // Session events should not overlap with those from previous
+ // sessions.
+ state.overlapping_sessions = true;
+ }
+ }
+ }
+ Env::Default()->SleepForMicroseconds(1 * EnvTime::kMillisToMicros);
+ }
+ stop.Notify();
+
+ for (const auto& thread : thread_state) {
+ EXPECT_FALSE(thread.split_session)
+ << "Expected contiguous events in a session";
+ EXPECT_FALSE(thread.overlapping_sessions) << "Expected disjoint sessions";
+ EXPECT_GT(thread.events.size(), 1)
+ << "Expected gaps in thread events between sessions";
+ }
+}
+
+} // namespace
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/lib/BUILD b/tensorflow/core/profiler/lib/BUILD
index 0320ae1..e12cc1e 100644
--- a/tensorflow/core/profiler/lib/BUILD
+++ b/tensorflow/core/profiler/lib/BUILD
@@ -13,17 +13,20 @@
)
tf_cuda_library(
- name = "eager_profiler",
+ name = "profiler_session",
srcs = [
- "eager_profiler.cc",
+ "profiler_session.cc",
],
hdrs = [
- "eager_profiler.h",
+ "profiler_session.h",
],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core/common_runtime/eager:context",
"//tensorflow/contrib/tpu/profiler:trace_events_proto_cc",
+ "//tensorflow/core/profiler/internal/gpu:tracer",
+ "//tensorflow/core/profiler/internal/runtime:eager_profiler",
+ "//tensorflow/core/profiler/internal:profiler_interface",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
@@ -40,3 +43,16 @@
],
}),
)
+
+tf_cuda_library(
+ name = "traceme",
+ srcs = ["traceme.cc"],
+ hdrs = ["traceme.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core/profiler/internal:traceme_recorder",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
diff --git a/tensorflow/core/profiler/lib/eager_profiler.cc b/tensorflow/core/profiler/lib/eager_profiler.cc
deleted file mode 100644
index 9293e7a..0000000
--- a/tensorflow/core/profiler/lib/eager_profiler.cc
+++ /dev/null
@@ -1,162 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/profiler/lib/eager_profiler.h"
-#include <string>
-#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h"
-#include "tensorflow/core/common_runtime/eager/context.h"
-#include "tensorflow/core/common_runtime/step_stats_collector.h"
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/platform/device_tracer.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/protobuf/config.pb.h"
-
-namespace tensorflow {
-
-namespace {
-
-void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata,
- tpu::Trace* trace,
- const uint64 profile_start_time_micros) {
- auto trace_devices = trace->mutable_devices();
- // TODO(fishx): use a lighter representation instead of GraphDef to insert
- // python information into trace event.
-
- for (size_t device_id = 0;
- device_id < run_metadata->step_stats().dev_stats_size(); ++device_id) {
- // Create device
- auto* device_stats =
- run_metadata->mutable_step_stats()->mutable_dev_stats(device_id);
- tensorflow::tpu::Device device;
- device.set_name(device_stats->device());
- device.set_device_id(device_id);
- tensorflow::tpu::Resource resource;
- resource.set_name("0");
- resource.set_resource_id(0);
- (*device.mutable_resources())[0] = resource;
- (*trace_devices)[device_id] = device;
-
- // Emit events.
- for (auto node :
- run_metadata->step_stats().dev_stats(device_id).node_stats()) {
- auto* event = trace->add_trace_events();
- auto* args = event->mutable_args();
- event->set_device_id(device_id);
- event->set_resource_id(0);
- event->set_name(node.node_name());
- event->set_timestamp_ps(
- (node.all_start_micros() - profile_start_time_micros) *
- EnvTime::kMicrosToPicos);
- event->set_duration_ps(node.all_end_rel_micros() *
- EnvTime::kMicrosToPicos);
- (*args)["label"] = node.timeline_label();
- }
- }
-
- // TODO(fishx): Convert allocation data as well.
-}
-
-} // namespace
-
-/*static*/ std::unique_ptr<EagerProfiler> EagerProfiler::Create(
- EagerContext* const context) {
- return absl::WrapUnique(new EagerProfiler(context));
-}
-
-void EagerProfiler::BeforeClearRunMetadata() {
- mutex_lock l(mutex_);
- run_metadata_.MergeFrom(*context_->RunMetadataProto());
-}
-
-Status EagerProfiler::Status() {
- mutex_lock l(mutex_);
- return status_;
-}
-
-Status EagerProfiler::SerializeToString(string* content) {
- mutex_lock l(mutex_);
- if (!status_.ok()) return status_;
- Stop();
-
- // Get profiling data from device tracer
- if (device_tracer_ != nullptr) {
- std::unique_ptr<StepStatsCollector> step_stats_collector(
- new StepStatsCollector(run_metadata_.mutable_step_stats()));
- tensorflow::Status s = device_tracer_->Collect(step_stats_collector.get());
- if (!s.ok()) {
- device_tracer_.reset(nullptr);
- LOG(WARNING) << "Failed to collect data from device tracer. "
- << s.error_message();
- }
- step_stats_collector->Finalize();
- }
-
- tpu::Trace trace;
-
- ConvertRunMetadataToTraceEvent(&run_metadata_, &trace, start_time_micros_);
-
- trace.SerializeToString(content);
- return Status::OK();
-}
-
-EagerProfiler::EagerProfiler(EagerContext* const context)
- : context_(context),
- start_time_micros_(Env::Default()->NowNanos() / EnvTime::kMicrosToNanos) {
- LOG(INFO) << "Eager Profiler started.";
-
- status_ = context_->RegisterRunMetadataListener(this);
- if (!status_.ok()) {
- context_ = nullptr;
- LOG(WARNING)
- << "Eager Profiler failed to start. Another profiler is running.";
- return;
- }
-
- // TODO(fishx): Allow user disable device tracer.
- device_tracer_ = CreateDeviceTracer();
- if (!device_tracer_) {
- LOG(WARNING) << "Continue profiling without device tracer. "
- << "Failed to create device tracer.";
- return;
- }
- class Status s = device_tracer_->Start();
- if (!s.ok()) {
- device_tracer_.reset(nullptr);
- LOG(WARNING) << "Continue profiling without device tracer. "
- << s.error_message();
- }
-}
-
-EagerProfiler::~EagerProfiler() { Stop(); }
-
-void EagerProfiler::Stop() {
- if (context_ != nullptr) {
- context_->ClearRunMetadataListener();
- run_metadata_.MergeFrom(*context_->RunMetadataProto());
- context_ = nullptr;
- if (device_tracer_ != nullptr) {
- tensorflow::Status s = device_tracer_->Stop();
- if (!s.ok()) {
- device_tracer_.reset(nullptr);
- LOG(WARNING) << "Failed to stop device tracer. " << s.error_message();
- }
- }
- LOG(INFO) << "Eager Profiler ended with status:" << status_;
- }
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/profiler/lib/profiler_session.cc b/tensorflow/core/profiler/lib/profiler_session.cc
new file mode 100644
index 0000000..1eb9ed6
--- /dev/null
+++ b/tensorflow/core/profiler/lib/profiler_session.cc
@@ -0,0 +1,136 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/lib/profiler_session.h"
+#include <string>
+#include "tensorflow/contrib/tpu/profiler/trace_events.pb.h"
+#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/internal/gpu/tracer.h"
+#include "tensorflow/core/profiler/internal/runtime/eager_profiler.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+namespace {
+
+void ConvertRunMetadataToTraceEvent(RunMetadata* run_metadata,
+ tpu::Trace* trace,
+ const uint64 profile_start_time_micros) {
+ auto trace_devices = trace->mutable_devices();
+ // TODO(fishx): use a lighter representation instead of GraphDef to insert
+ // python information into trace event.
+
+ for (size_t device_id = 0;
+ device_id < run_metadata->step_stats().dev_stats_size(); ++device_id) {
+ // Create device
+ auto* device_stats =
+ run_metadata->mutable_step_stats()->mutable_dev_stats(device_id);
+ tensorflow::tpu::Device device;
+ device.set_name(device_stats->device());
+ device.set_device_id(device_id);
+ tensorflow::tpu::Resource resource;
+ resource.set_name("0");
+ resource.set_resource_id(0);
+ (*device.mutable_resources())[0] = resource;
+ for (const auto& thread_name : device_stats->thread_names()) {
+ tensorflow::tpu::Resource resource;
+ resource.set_resource_id(thread_name.first);
+ resource.set_name(thread_name.second);
+ (*device.mutable_resources())[thread_name.first] = resource;
+ }
+ (*trace_devices)[device_id] = device;
+
+ // Emit events.
+ for (auto node :
+ run_metadata->step_stats().dev_stats(device_id).node_stats()) {
+ auto* event = trace->add_trace_events();
+ auto* args = event->mutable_args();
+ event->set_device_id(device_id);
+ if (device_stats->device().find("host:CPU") != string::npos) {
+ event->set_resource_id(node.thread_id());
+ } else {
+ event->set_resource_id(0);
+ }
+ event->set_name(node.node_name());
+ event->set_timestamp_ps(
+ (node.all_start_micros() - profile_start_time_micros) *
+ EnvTime::kMicrosToPicos);
+ event->set_duration_ps(node.all_end_rel_micros() *
+ EnvTime::kMicrosToPicos);
+ (*args)["label"] = node.timeline_label();
+ }
+ }
+
+ // TODO(fishx): Convert allocation data as well.
+}
+
+} // namespace
+
+/*static*/ std::unique_ptr<ProfilerSession> ProfilerSession::Create(
+ EagerContext* const context) {
+ return absl::WrapUnique(new ProfilerSession(context));
+}
+
+Status ProfilerSession::Status() {
+ mutex_lock l(mutex_);
+ return status_;
+}
+
+Status ProfilerSession::SerializeToString(string* content) {
+ mutex_lock l(mutex_);
+ if (!status_.ok()) return status_;
+ for (auto& profiler : profilers_) {
+ profiler->Stop().IgnoreError();
+ }
+ RunMetadata run_metadata;
+ for (auto& profiler : profilers_) {
+ profiler->CollectData(&run_metadata).IgnoreError();
+ }
+
+ tpu::Trace trace;
+
+ ConvertRunMetadataToTraceEvent(&run_metadata, &trace, start_time_micros_);
+
+ trace.SerializeToString(content);
+ return Status::OK();
+}
+
+ProfilerSession::ProfilerSession(EagerContext* const context)
+ : start_time_micros_(Env::Default()->NowNanos() / EnvTime::kMicrosToNanos) {
+ LOG(INFO) << "Profile Session started.";
+
+ if (context != nullptr) {
+ profilers_.push_back(
+ tensorflow::profiler::runtime::EagerProfiler::Create(context));
+ }
+ profilers_.push_back(tensorflow::profiler::gpu::Tracer::Create());
+
+ status_ = Status::OK();
+
+ for (auto& profiler : profilers_) {
+ profiler->Start().IgnoreError();
+ }
+}
+
+ProfilerSession::~ProfilerSession() {
+ for (auto& profiler : profilers_) {
+ profiler->Stop().IgnoreError();
+ }
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/lib/eager_profiler.h b/tensorflow/core/profiler/lib/profiler_session.h
similarity index 62%
rename from tensorflow/core/profiler/lib/eager_profiler.h
rename to tensorflow/core/profiler/lib/profiler_session.h
index 7cdb76f..1ab4825 100644
--- a/tensorflow/core/profiler/lib/eager_profiler.h
+++ b/tensorflow/core/profiler/lib/profiler_session.h
@@ -12,14 +12,13 @@
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROFILER_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROFILER_H_
+#ifndef TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_
+#define TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_
#include "tensorflow/core/common_runtime/eager/context.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/platform/device_tracer.h"
#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/protobuf/config.pb.h"
+#include "tensorflow/core/profiler/internal/profiler_interface.h"
namespace tensorflow {
@@ -29,38 +28,34 @@
// Multiple instances of it can be created, but at most one of them will profile
// for each EagerContext. Status() will return OK only for the instance that is
// profiling.
-// Thread-safety: TFE_Profiler is thread-safe.
-class EagerProfiler : RunMetadataListener {
+// Thread-safety: ProfilerSession is thread-safe.
+class ProfilerSession {
public:
- // Creates and EagerProfiler and starts profiling.
- static std::unique_ptr<EagerProfiler> Create(EagerContext* const context);
+ // Creates and ProfilerSession and starts profiling.
+ static std::unique_ptr<ProfilerSession> Create(EagerContext* const context);
// Deletes an exsiting Profiler and enables starting a new one.
- ~EagerProfiler() override;
+ ~ProfilerSession();
- void BeforeClearRunMetadata() override LOCKS_EXCLUDED(mutex_)
- EXCLUSIVE_LOCKS_REQUIRED(context_->MetadataMu());
tensorflow::Status Status() LOCKS_EXCLUDED(mutex_);
tensorflow::Status SerializeToString(string* content) LOCKS_EXCLUDED(mutex_);
private:
// Constructs an instance of the class and starts profiling
- explicit EagerProfiler(EagerContext* const context);
+ explicit ProfilerSession(EagerContext* const context);
// Profiler is neither copyable or movable.
- EagerProfiler(const EagerProfiler&) = delete;
- EagerProfiler& operator=(const EagerProfiler&) = delete;
+ ProfilerSession(const ProfilerSession&) = delete;
+ ProfilerSession& operator=(const ProfilerSession&) = delete;
- void Stop() EXCLUSIVE_LOCKS_REQUIRED(mutex_);
+ std::vector<std::unique_ptr<tensorflow::profiler::ProfilerInterface>>
+ profilers_ GUARDED_BY(mutex_);
- RunMetadata run_metadata_ GUARDED_BY(mutex_);
tensorflow::Status status_ GUARDED_BY(mutex_);
- std::unique_ptr<DeviceTracer> device_tracer_ GUARDED_BY(mutex_);
- EagerContext* context_ GUARDED_BY(mutex_);
const uint64 start_time_micros_;
mutex mutex_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_PROFILER_H_
+#endif // TENSORFLOW_CORE_PROFILER_LIB_PROFILER_SESSION_H_
diff --git a/tensorflow/core/profiler/lib/traceme.cc b/tensorflow/core/profiler/lib/traceme.cc
new file mode 100644
index 0000000..90272b8
--- /dev/null
+++ b/tensorflow/core/profiler/lib/traceme.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/profiler/lib/traceme.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// Activity IDs: To avoid contention over a counter, the top 32 bits identify
+// the originating thread, the bottom 32 bits name the event within a thread.
+// IDs may be reused after 4 billion events on one thread, or 4 billion threads.
+static std::atomic<uint32> thread_counter(1); // avoid kUntracedActivity
+uint64 NewActivityId() {
+ const thread_local static uint32 thread_id = thread_counter.fetch_add(1);
+ thread_local static uint32 per_thread_activity_id = 0;
+ return static_cast<uint64>(thread_id) << 32 | per_thread_activity_id++;
+}
+
+/* static */ uint64 TraceMe::ActivityStartImpl(
+ absl::string_view activity_name) {
+ uint64 activity_id = NewActivityId();
+ TraceMeRecorder::Record({activity_id, string(activity_name),
+ /*start_time=*/Env::Default()->NowNanos(),
+ /*end_time=*/0});
+ return activity_id;
+}
+
+/* static */ void TraceMe::ActivityEndImpl(uint64 activity_id) {
+ TraceMeRecorder::Record({activity_id, /*name=*/"", /*start_time=*/0,
+ /*end_time=*/Env::Default()->NowNanos()});
+}
+
+} // namespace profiler
+} // namespace tensorflow
diff --git a/tensorflow/core/profiler/lib/traceme.h b/tensorflow/core/profiler/lib/traceme.h
new file mode 100644
index 0000000..b9fae3d
--- /dev/null
+++ b/tensorflow/core/profiler/lib/traceme.h
@@ -0,0 +1,192 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_
+#define TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_
+
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "absl/types/optional.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/profiler/internal/traceme_recorder.h"
+
+namespace tensorflow {
+namespace profiler {
+
+// This is specifically used in xprof_bridge for instrumenting Tensorflow ops.
+// Takes input as whether a TF op is expensive or not and returns the TraceMe
+// level to be assigned to trace that particular op. Assigns level 2 for
+// expensive ops (these are high-level details and shown by default in xprof
+// UI). Assigns level 3 for cheap ops (low-level details not shown by default).
+inline int GetTFTraceMeLevel(bool is_expensive) { return is_expensive ? 2 : 3; }
+
+// This class permits user-specified (CPU) tracing activities. A trace activity
+// is started when an object of this class is created and stopped when the
+// object is destroyed.
+//
+// CPU tracing can be useful when trying to understand what parts of GPU
+// computation (e.g., kernels and memcpy) correspond to higher level activities
+// in the overall program. For instance, a collection of kernels maybe
+// performing one "step" of a program that is better visualized together than
+// interspersed with kernels from other "steps". Therefore, a TraceMe object
+// can be created at each "step".
+//
+// Two APIs are provided:
+// (1) Scoped object: a TraceMe object starts tracing on construction, and
+// stops tracing when it goes out of scope.
+// {
+// TraceMe trace("step");
+// ... do some work ...
+// }
+// TraceMe objects can be members of a class, or allocated on the heap.
+// (2) Static methods: ActivityStart and ActivityEnd may be called in pairs.
+// auto id = ActivityStart("step");
+// ... do some work ...
+// ActivityEnd(id);
+class TraceMe {
+ public:
+ // Constructor that traces a user-defined activity labeled with activity_name
+ // in the UI. Level defines the trace priority, used for filtering TraceMe
+ // events. By default, traces with TraceMe level <= 2 are recorded. Levels:
+ // - Must be a positive integer.
+ // - Level 1 is the default and used only for user instrumentation.
+ // - Level 2 is used by xprof for instrumenting high level program execution
+ // details (expensive TF ops, XLA ops, etc).
+ // - Level 3 is also used by xprof to instrument more verbose (low-level)
+ // program execution details (cheap TF ops, etc).
+ // Users are welcome to use level >= 2 in their code, if they wish to filter
+ // out their host traces based on verbosity.
+ explicit TraceMe(absl::string_view activity_name, int level = 1) {
+ DCHECK_GE(level, 1);
+ if (TraceMeRecorder::Active(level)) {
+ new (&no_init_.name) string(activity_name);
+ start_time_ = Env::Default()->NowNanos();
+ } else {
+ start_time_ = kUntracedActivity;
+ }
+ }
+
+ // string&& constructor to prevent an unnecessary string copy, e.g. when a
+ // TraceMe is constructed based on the result of a StrCat operation.
+ // Note: We can't take the string by value because a) it would make the
+ // overloads ambiguous, and b) we want lvalue strings to use the string_view
+ // constructor so we avoid copying them when tracing is disabled.
+ explicit TraceMe(string &&activity_name, int level = 1) {
+ DCHECK_GE(level, 1);
+ if (TraceMeRecorder::Active(level)) {
+ new (&no_init_.name) string(std::move(activity_name));
+ start_time_ = Env::Default()->NowNanos();
+ } else {
+ start_time_ = kUntracedActivity;
+ }
+ }
+
+ // Do not allow passing strings by reference or value since the caller
+ // may unintentionally maintain ownership of the activity_name.
+ // Explicitly std::move the activity_name or wrap it in a string_view if
+ // you really wish to maintain ownership.
+ explicit TraceMe(const string &activity_name, int level = 1) = delete;
+
+ // This overload is necessary to make TraceMe's with string literals work.
+ // Otherwise, the string&& and the string_view constructor would be equally
+ // good overload candidates.
+ explicit TraceMe(const char *raw, int level = 1)
+ : TraceMe(absl::string_view(raw), level) {}
+
+ // This overload only generates the activity name if tracing is enabled.
+ // Useful for avoiding things like string concatenation when tracing is
+ // disabled. The |name_generator| may be a lambda or functor that returns a
+ // type that the string() constructor can take.
+ // name_generator is templated, rather than a std::function to avoid
+ // allocations std::function might make even if never called.
+ // Usage: xprof::TraceMe([&]{ return StrCat(prefix, ":", postfix); });
+ template <typename NameGeneratorT>
+ explicit TraceMe(NameGeneratorT name_generator, int level = 1) {
+ DCHECK_GE(level, 1);
+ if (TraceMeRecorder::Active(level)) {
+ new (&no_init_.name) string(name_generator());
+ start_time_ = Env::Default()->NowNanos();
+ } else {
+ start_time_ = kUntracedActivity;
+ }
+ }
+
+ ~TraceMe() {
+ // We do not need to check the trace level again here.
+ // - If tracing wasn't active to start with, we have kUntracedActivity.
+ // - If tracing was active and was stopped, we have
+ // TraceMeRecorder::Active().
+ // - If tracing was active and was restarted at a lower level, we may
+ // spuriously record the event. This is extremely rare, and acceptable as
+ // event will be discarded when its start timestamp fall outside of the
+ // start/stop session timestamp (recorded in XprofResponse).
+ if (start_time_ != kUntracedActivity) {
+ if (TraceMeRecorder::Active()) {
+ TraceMeRecorder::Record({kCompleteActivity, std::move(no_init_.name),
+ start_time_, Env::Default()->NowNanos()});
+ }
+ no_init_.name.~string();
+ }
+ }
+
+ // TraceMe is not movable or copyable.
+ TraceMe(const TraceMe &) = delete;
+ TraceMe &operator=(const TraceMe &) = delete;
+
+ // Static API, for use when scoped objects are inconvenient.
+
+ // Record the start time of an activity.
+ // Returns the activity ID, which is used to stop the activity.
+ static uint64 ActivityStart(absl::string_view name, int level = 1) {
+ return TraceMeRecorder::Active(level) ? ActivityStartImpl(name)
+ : kUntracedActivity;
+ }
+
+ // Record the end time of an activity started by ActivityStart().
+ static void ActivityEnd(uint64 activity_id) {
+ // We don't check the level again (see ~TraceMe()).
+ if (activity_id != kUntracedActivity) {
+ if (TraceMeRecorder::Active()) {
+ ActivityEndImpl(activity_id);
+ }
+ }
+ }
+
+ private:
+ // Activity ID or start time used when tracing is disabled.
+ constexpr static uint64 kUntracedActivity = 0;
+ // Activity ID used as a placeholder when both start and end are present.
+ constexpr static uint64 kCompleteActivity = 1;
+
+ static uint64 ActivityStartImpl(absl::string_view activity_name);
+ static void ActivityEndImpl(uint64 activity_id);
+
+ // Wrap the name into a union so that we can avoid the cost of string
+ // initialization when tracing is disabled.
+ union NoInit {
+ NoInit() {}
+ ~NoInit() {}
+ string name;
+ } no_init_;
+
+ uint64 start_time_;
+};
+
+} // namespace profiler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PROFILER_LIB_TRACEME_H_
diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD
index 83ec75d..6b4576e 100644
--- a/tensorflow/core/profiler/rpc/BUILD
+++ b/tensorflow/core/profiler/rpc/BUILD
@@ -14,7 +14,7 @@
"//tensorflow/contrib/tpu/profiler:tpu_profiler_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core/common_runtime/eager:context",
- "//tensorflow/core/profiler/lib:eager_profiler",
+ "//tensorflow/core/profiler/lib:profiler_session",
],
alwayslink = 1,
)
diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc
index 08affff..835aa1e 100644
--- a/tensorflow/core/profiler/rpc/profiler_server.cc
+++ b/tensorflow/core/profiler/rpc/profiler_server.cc
@@ -26,7 +26,7 @@
std::unique_ptr<Thread> StartProfilerServer(EagerContext* const eager_context,
int32 port) {
return WrapUnique(eager_context->TFEnv()->StartThread(
- {}, "profiler server", [eager_context, port]() {
+ {}, "profiler_server", [eager_context, port]() {
string server_address = strings::StrCat("0.0.0.0:", port);
std::unique_ptr<TPUProfiler::Service> service =
CreateProfilerService(eager_context);
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.cc b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
index bde2ff2..872ef2a 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.cc
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.cc
@@ -17,7 +17,7 @@
#include "grpcpp/support/status.h"
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
#include "tensorflow/core/common_runtime/eager/context.h"
-#include "tensorflow/core/profiler/lib/eager_profiler.h"
+#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -38,8 +38,8 @@
::grpc::Status Profile(::grpc::ServerContext* ctx, const ProfileRequest* req,
ProfileResponse* response) override {
LOG(INFO) << "Received a profile request.";
- std::unique_ptr<EagerProfiler> profiler =
- EagerProfiler::Create(eager_context_);
+ std::unique_ptr<ProfilerSession> profiler =
+ ProfilerSession::Create(eager_context_);
if (!profiler->Status().ok()) {
return ::grpc::Status(::grpc::StatusCode::INTERNAL,
profiler->Status().error_message());
diff --git a/tensorflow/core/profiler/rpc/profiler_service_impl.h b/tensorflow/core/profiler/rpc/profiler_service_impl.h
index 79dc767..311a267 100644
--- a/tensorflow/core/profiler/rpc/profiler_service_impl.h
+++ b/tensorflow/core/profiler/rpc/profiler_service_impl.h
@@ -20,7 +20,6 @@
#include "grpcpp/support/status.h"
#include "tensorflow/contrib/tpu/profiler/tpu_profiler.grpc.pb.h"
#include "tensorflow/core/common_runtime/eager/context.h"
-#include "tensorflow/core/profiler/lib/eager_profiler.h"
namespace tensorflow {
std::unique_ptr<TPUProfiler::Service> CreateProfilerService(
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index a2cc1bc..44e9854 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -156,6 +156,16 @@
// CollectiveReduce, and serves as an override to automatic ring order
// generation in OrderTaskDeviceMap() during CollectiveParam resolution.
string collective_ring_order = 4;
+
+ // If true then extra work is done by GPUDevice and GPUBFCAllocator to
+ // keep track of when GPU memory is freed and when kernels actually
+ // complete so that we can know when a nominally free memory chunk
+ // is really not subject to pending use.
+ bool timestamped_allocator = 5;
+
+ // If > 0 limit the number of pending kernels on any compute
+ // stream to this number.
+ int32 pending_cap = 6;
}
// Everything inside experimental is subject to change and is not subject
@@ -429,6 +439,10 @@
// If true, make collective op execution order sequential and deterministic
// for potentially concurrent collective instances.
bool collective_deterministic_sequential_execution = 6;
+
+ // If true, use NCCL for CollectiveOps. This feature is highly
+ // experimental.
+ bool collective_nccl = 7;
};
Experimental experimental = 16;
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 17e76c4..b5c9599 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -76,7 +76,7 @@
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
- // Force small ops onto the CPU (default is OFF).
+ // Force small ops onto the CPU (default is ON).
Toggle pin_to_host_optimization = 18;
// Disable the entire meta optimizer (off by default).
bool disable_meta_optimizer = 19;
diff --git a/tensorflow/contrib/tpu/proto/BUILD b/tensorflow/core/protobuf/tpu/BUILD
similarity index 100%
rename from tensorflow/contrib/tpu/proto/BUILD
rename to tensorflow/core/protobuf/tpu/BUILD
diff --git a/tensorflow/contrib/tpu/proto/compilation_result.proto b/tensorflow/core/protobuf/tpu/compilation_result.proto
similarity index 100%
rename from tensorflow/contrib/tpu/proto/compilation_result.proto
rename to tensorflow/core/protobuf/tpu/compilation_result.proto
diff --git a/tensorflow/contrib/tpu/proto/dynamic_padding.proto b/tensorflow/core/protobuf/tpu/dynamic_padding.proto
similarity index 100%
rename from tensorflow/contrib/tpu/proto/dynamic_padding.proto
rename to tensorflow/core/protobuf/tpu/dynamic_padding.proto
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
similarity index 89%
rename from tensorflow/contrib/tpu/proto/optimization_parameters.proto
rename to tensorflow/core/protobuf/tpu/optimization_parameters.proto
index bc50c61..a4ca7f3 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/core/protobuf/tpu/optimization_parameters.proto
@@ -89,11 +89,11 @@
// the normal version of Adam that updates all parameters in the embedding
// table, even for entries that are not used in the current minibatch
// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
-// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
-// order to get correct results; a warning will be printed otherwise (which may
-// change to an error in the future). If use_sum_inside_sqrt is set, the Adam
-// variable update formula will be changed from m / (sqrt(v) + epsilon) to
-// m / sqrt(v + epsilon**2); this option improves the performance of TPU
+// use_non_lazy_adam is enabled, gradient accumulation is also required to be
+// enabled in order to get correct results; a warning will be printed otherwise
+// (which may change to an error in the future). If use_sum_inside_sqrt is set,
+// the Adam variable update formula will be changed from m / (sqrt(v) + epsilon)
+// to m / sqrt(v + epsilon**2); this option improves the performance of TPU
// training and is not expected to harm model quality.
message AdamParameters {
float beta1 = 3;
@@ -170,6 +170,20 @@
float initial_accumulator = 3;
}
+// Status of using gradient accumulation (doing two passes over the input
+// gradients: one to accumulate them into a temporary array and another to apply
+// them using the actual optimization algorithm). The extra message is to wrap
+// the enum for scoping.
+message GradientAccumulationStatus {
+ // Defaults to value of use_gradient_accumulation (temporary compatibility
+ // behavior).
+ enum Status {
+ UNSPECIFIED = 0;
+ ENABLED = 1;
+ DISABLED = 2;
+ }
+};
+
message OptimizationParameters {
// Learning rate used for updating the embedding layer parameters.
LearningRate learning_rate = 13;
@@ -191,11 +205,13 @@
// once per minibatch.
float weight_decay_factor = 16;
- // Whether to use gradient accumulation (do two passes over the input
+ // Status of using gradient accumulation (doing two passes over the input
// gradients: one to accumulate them into a temporary array and another to
- // apply them using the actual optimization algorithm). This feature is
- // experimental -- it has not been fully verified and may cause training
- // crashes and/or failures.
+ // apply them using the actual optimization algorithm).
+ GradientAccumulationStatus.Status gradient_accumulation_status = 17;
+
+ // Old gradient accumulation flag; overridden by gradient_accumulation_status
+ // when it is set.
bool use_gradient_accumulation = 15;
// Optimization algorithm parameters; which field is selected determines which
diff --git a/tensorflow/contrib/tpu/proto/topology.proto b/tensorflow/core/protobuf/tpu/topology.proto
similarity index 100%
rename from tensorflow/contrib/tpu/proto/topology.proto
rename to tensorflow/core/protobuf/tpu/topology.proto
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
similarity index 96%
rename from tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
rename to tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
index da19b13..53280ed 100644
--- a/tensorflow/contrib/tpu/proto/tpu_embedding_configuration.proto
+++ b/tensorflow/core/protobuf/tpu/tpu_embedding_configuration.proto
@@ -2,8 +2,8 @@
package tensorflow.tpu;
-import "tensorflow/contrib/tpu/proto/optimization_parameters.proto";
-import "tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto";
+import "tensorflow/core/protobuf/tpu/optimization_parameters.proto";
+import "tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto";
message TPUEmbeddingConfiguration {
// Description of the various embedding tables.
diff --git a/tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto b/tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto
similarity index 100%
rename from tensorflow/contrib/tpu/proto/tpu_embedding_output_layout.proto
rename to tensorflow/core/protobuf/tpu/tpu_embedding_output_layout.proto
diff --git a/tensorflow/core/util/device_name_utils.cc b/tensorflow/core/util/device_name_utils.cc
index cb088fa..56e6188 100644
--- a/tensorflow/core/util/device_name_utils.cc
+++ b/tensorflow/core/util/device_name_utils.cc
@@ -289,6 +289,30 @@
return true;
}
+void DeviceNameUtils::EnsureSpecification(ParsedName* more_specific,
+ const ParsedName& less_specific) {
+ if (less_specific.has_job) {
+ more_specific->has_job = true;
+ more_specific->job = less_specific.job;
+ }
+ if (less_specific.has_replica) {
+ more_specific->has_replica = true;
+ more_specific->replica = less_specific.replica;
+ }
+ if (less_specific.has_task) {
+ more_specific->has_task = true;
+ more_specific->task = less_specific.task;
+ }
+ if (less_specific.has_type) {
+ more_specific->has_type = true;
+ more_specific->type = less_specific.type;
+ }
+ if (less_specific.has_id) {
+ more_specific->has_id = true;
+ more_specific->id = less_specific.id;
+ }
+}
+
/* static */
bool DeviceNameUtils::IsCompleteSpecification(const ParsedName& pattern,
const ParsedName& name) {
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index bb5e2b3..b047e81 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -110,6 +110,11 @@
static bool IsSpecification(const ParsedName& less_specific,
const ParsedName& more_specific);
+ // Makes minimal changes to more_specific so that it becomes a
+ // specification of less_specific.
+ static void EnsureSpecification(ParsedName* more_specific,
+ const ParsedName& less_specific);
+
// Like IsSpecification, but the second argument "name" must have a
// non-wildcard value for all of its components.
static bool IsCompleteSpecification(const ParsedName& pattern,
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 548f548..91f9bc0 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <list>
#include <memory>
#include <string>
#include <unordered_map>
@@ -34,8 +35,7 @@
#endif
#ifdef INTEL_MKL_ML_ONLY
-#error \
- "Compiling for INTEL MKL ML only is no longer supported.Please use MKL DNN (the default option for --config=mkl)"
+#error "Please use INTEL MKL DNN (the default option for --config=mkl)."
#endif
#ifdef INTEL_MKL_ML_ONLY
@@ -86,7 +86,7 @@
// For use with MKL ML, has been deprecated
typedef enum { W = 0, H = 1, C = 2, N = 3 } MklDims;
-// The dimensions order that MKL DNN internally uses for 2D activations
+// The dimensions order that MKL-DNN internally uses for 2D activations
// [Batch, Channel, Height, Width] and
// for 2D filters [Out_Channel, In_Channel, Height, Width].
typedef enum {
@@ -98,7 +98,7 @@
Dim_I = 1
} MklDnnDims;
-// The dimensions order that MKL DNN internally uses for 3D activations
+// The dimensions order that MKL-DNN internally uses for 3D activations
// [Batch, Channel, Depth, Height, Width] and
// for 3D filters [Out_Channel, In_Channel, Depth, Height, Width].
typedef enum {
@@ -130,7 +130,7 @@
TF_3DFILTER_DIM_O = 4
} TFFilterDims3d;
-// The dimensions order that MKL DNN requires for the filter in a grouped
+// The dimensions order that MKL-DNN requires for the filter in a grouped
// convolution (2D only)
typedef enum {
MKL_GROUP_FILTER_DIM_G = 0,
@@ -837,7 +837,6 @@
return mkl_tensor; // return input since it is already TF tensor
TensorShape output_shape = mkl_shape.GetTfShape();
- ;
// Allocate output tensor.
context->allocate_temp(DataTypeToEnum<T>::v(), output_shape,
@@ -2061,6 +2060,111 @@
const mkldnn::memory::dims NONE_DIMS = {};
+//
+// LRUCache is a class which implements LRU (Least Recently Used) cache.
+// The implementation is similar to that of
+// tensorflow/core/platform/cloud/expiring_lru_cache.h
+// without its thread-safe part because the cache is supposed to be
+// used as thread local (for instance, MklPrimitive caching).
+//
+// The LRU list maintains objects in chronological order based on
+// creation time, with the least recently accessed object at the
+// tail of LRU list, while the most recently accessed object
+// at the head of LRU list.
+//
+// This class is used to maintain an upper bound on the total number of
+// cached items. When the cache reaches its capacity, the LRU item will
+// be removed and replaced by a new one from SetOp call.
+//
+template <typename T>
+class LRUCache {
+ public:
+ explicit LRUCache(size_t capacity) {
+ capacity_ = capacity;
+ Clear();
+ }
+
+ T* GetOp(const string& key) {
+ auto it = cache_.find(key);
+ if (it == cache_.end()) {
+ return nullptr;
+ }
+
+ // Move to the front of LRU list as the most recently accessed.
+ lru_list_.erase(it->second.lru_iterator);
+ lru_list_.push_front(it->first);
+ it->second.lru_iterator = lru_list_.begin();
+ return it->second.op;
+ }
+
+ void SetOp(const string& key, T* op) {
+ if (lru_list_.size() >= capacity_) {
+ Delete();
+ }
+
+ // Insert an entry to the front of the LRU list
+ lru_list_.push_front(key);
+ Entry entry(op, lru_list_.begin());
+ cache_.emplace(std::make_pair(key, std::move(entry)));
+ }
+
+ void Clear() {
+ if (lru_list_.empty()) return;
+
+ // Clean up the cache
+ cache_.clear();
+ lru_list_.clear();
+ }
+
+ private:
+ struct Entry {
+ // The entry's value.
+ T* op;
+
+ // A list iterator pointing to the entry's position in the LRU list.
+ std::list<string>::iterator lru_iterator;
+
+ // Constructor
+ Entry(T* op, std::list<string>::iterator it) {
+ this->op = op;
+ this->lru_iterator = it;
+ }
+
+ // Move construcctor
+ Entry(Entry&& source) noexcept
+ : lru_iterator(std::move(source.lru_iterator)) {
+ op = std::move(source.op);
+ source.op = std::forward<T*>(nullptr);
+ }
+
+ // Destructor
+ ~Entry() {
+ if (op != nullptr) delete op;
+ }
+ };
+
+ // Remove the least recently accessed entry from LRU list, which
+ // is the tail of lru_list_. Update cache_ correspondingly.
+ bool Delete() {
+ if (lru_list_.empty()) return false;
+ string key = lru_list_.back();
+ lru_list_.pop_back();
+ cache_.erase(key);
+ return true;
+ }
+
+ // Cache capacity
+ size_t capacity_;
+
+ // The cache, a map from string key to a LRU entry.
+ std::unordered_map<string, Entry> cache_;
+
+ // The LRU list of entries.
+ // The front of the list contains the key of the most recently accessed
+ // entry, while the back of the list is the least recently accessed entry.
+ std::list<string> lru_list_;
+};
+
template <typename T>
class MklPrimitiveFactory {
public:
@@ -2069,23 +2173,13 @@
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
- auto& map = MklPrimitiveFactory<T>::GetHashMap();
- auto stream_iter = map.find(key);
- if (stream_iter == map.end()) {
- return nullptr;
- } else {
- CHECK(stream_iter->second != nullptr) << "nullptr present in map";
- return stream_iter->second;
- }
+ auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
+ return lru_cache.GetOp(key);
}
void SetOp(const string& key, MklPrimitive* op) {
- auto& map = MklPrimitiveFactory<T>::GetHashMap();
- auto stream_iter = map.find(key);
-
- CHECK(stream_iter == map.end());
-
- map[key] = op;
+ auto& lru_cache = MklPrimitiveFactory<T>::GetLRUCache();
+ lru_cache.SetOp(key, op);
}
/// Function to decide whether HW has AVX512 or AVX2
@@ -2105,9 +2199,10 @@
}
private:
- static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
- static thread_local std::unordered_map<string, MklPrimitive*> map_;
- return map_;
+ static inline LRUCache<MklPrimitive>& GetLRUCache() {
+ static const int kCapacity = 1024; // cache capacity
+ static thread_local LRUCache<MklPrimitive> lru_cache_(kCapacity);
+ return lru_cache_;
}
};
diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc
index 4f837f1..bed6feb 100644
--- a/tensorflow/core/util/mkl_util_test.cc
+++ b/tensorflow/core/util/mkl_util_test.cc
@@ -84,6 +84,40 @@
EXPECT_EQ(b_md2.data.format, mkldnn_blocked);
}
+TEST(MklUtilTest, LRUCacheTest) {
+ // The cached objects are of type int*
+ size_t capacity = 100;
+ size_t num_objects = capacity + 10;
+ LRUCache<int> lru_cache(capacity);
+
+ // Test SetOp: be able to set more ops than the capacity
+ for (int k = 0; k < num_objects; k++) {
+ lru_cache.SetOp(std::to_string(k), new int(k));
+ }
+
+ // Test GetOp and capacity:
+ // Least recently accessed objects should not be in cache any more.
+ for (int k = 0; k < num_objects - capacity; ++k) {
+ EXPECT_EQ(nullptr, lru_cache.GetOp(std::to_string(k)));
+ }
+
+ // Test GetOp and capacity:
+ // Most recently accessed objects should still be in cache.
+ for (int k = num_objects - capacity; k < num_objects; ++k) {
+ int* int_ptr = lru_cache.GetOp(std::to_string(k));
+ EXPECT_NE(nullptr, int_ptr);
+ EXPECT_EQ(*int_ptr, k);
+ }
+
+ // Clean up the cache
+ lru_cache.Clear();
+
+ // After clean up, there should be no cached object.
+ for (int k = 0; k < num_objects; ++k) {
+ EXPECT_EQ(nullptr, lru_cache.GetOp(std::to_string(k)));
+ }
+}
+
#endif // INTEL_MKL_ML_ONLY
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/examples/autograph/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
index 3fe33df..72b62f1 100644
--- a/tensorflow/examples/autograph/integration_tests/keras_test.py
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -87,18 +87,16 @@
@test_util.run_deprecated_v1
def test_recursive_true(self):
- with self.assertRaisesRegexp(NotImplementedError,
- 'Object conversion is not yet supported.'):
- with tf.Graph().as_default():
- model = CompoundModel()
- model.build(tf.TensorShape((None, 10, 10, 1)))
- init = tf.global_variables_initializer()
+ with tf.Graph().as_default():
+ model = CompoundModel()
+ model.build(tf.TensorShape((None, 10, 10, 1)))
+ init = tf.global_variables_initializer()
- with tf.Session() as sess:
- self.evaluate(init)
- sample_input = tf.random_uniform((1, 10, 10, 1))
- output = model(sample_input) # pylint: disable=not-callable
- self.assertEqual(self.evaluate(output).shape, (1, 3))
+ with tf.Session() as sess:
+ self.evaluate(init)
+ sample_input = tf.random_uniform((1, 10, 10, 1))
+ output = model(sample_input) # pylint: disable=not-callable
+ self.assertEqual(self.evaluate(output).shape, (1, 3))
if __name__ == '__main__':
diff --git a/tensorflow/examples/speech_commands/BUILD b/tensorflow/examples/speech_commands/BUILD
index 7f3c764..ca044e5 100644
--- a/tensorflow/examples/speech_commands/BUILD
+++ b/tensorflow/examples/speech_commands/BUILD
@@ -76,8 +76,25 @@
],
)
+tf_py_test(
+ name = "train_test",
+ size = "small",
+ srcs = ["train_test.py"],
+ additional_deps = [
+ ":train",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_binary(
name = "freeze",
+ srcs = ["freeze.py"],
+ srcs_version = "PY2AND3",
+ deps = [":freeze_lib"],
+)
+
+py_library(
+ name = "freeze_lib",
srcs = [
"freeze.py",
],
@@ -103,6 +120,13 @@
py_binary(
name = "wav_to_features",
+ srcs = ["wav_to_features.py"],
+ srcs_version = "PY2AND3",
+ deps = [":wav_to_features_lib"],
+)
+
+py_library(
+ name = "wav_to_features_lib",
srcs = [
"wav_to_features.py",
],
@@ -128,6 +152,13 @@
py_binary(
name = "generate_streaming_test_wav",
+ srcs = ["generate_streaming_test_wav.py"],
+ srcs_version = "PY2AND3",
+ deps = [":generate_streaming_test_wav_lib"],
+)
+
+py_library(
+ name = "generate_streaming_test_wav_lib",
srcs = [
"generate_streaming_test_wav.py",
],
@@ -168,6 +199,13 @@
py_binary(
name = "label_wav",
+ srcs = ["label_wav.py"],
+ srcs_version = "PY2AND3",
+ deps = [":label_wav_lib"],
+)
+
+py_library(
+ name = "label_wav_lib",
srcs = [
"label_wav.py",
],
diff --git a/tensorflow/examples/speech_commands/train_test.py b/tensorflow/examples/speech_commands/train_test.py
new file mode 100644
index 0000000..db19576
--- /dev/null
+++ b/tensorflow/examples/speech_commands/train_test.py
@@ -0,0 +1,144 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for data input for speech commands."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import tensorflow as tf
+
+from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
+from tensorflow.examples.speech_commands import train
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+
+
+# Used to convert a dictionary into an object, for mocking parsed flags.
+class DictStruct(object):
+
+ def __init__(self, **entries):
+ self.__dict__.update(entries)
+
+
+class TrainTest(test.TestCase):
+
+ def _getWavData(self):
+ with self.cached_session():
+ sample_data = tf.zeros([32000, 2])
+ wav_encoder = contrib_audio.encode_wav(sample_data, 16000)
+ wav_data = self.evaluate(wav_encoder)
+ return wav_data
+
+ def _saveTestWavFile(self, filename, wav_data):
+ with open(filename, 'wb') as f:
+ f.write(wav_data)
+
+ def _saveWavFolders(self, root_dir, labels, how_many):
+ wav_data = self._getWavData()
+ for label in labels:
+ dir_name = os.path.join(root_dir, label)
+ os.mkdir(dir_name)
+ for i in range(how_many):
+ file_path = os.path.join(dir_name, 'some_audio_%d.wav' % i)
+ self._saveTestWavFile(file_path, wav_data)
+
+ def _prepareDummyTrainingData(self):
+ tmp_dir = self.get_temp_dir()
+ wav_dir = os.path.join(tmp_dir, 'wavs')
+ os.mkdir(wav_dir)
+ self._saveWavFolders(wav_dir, ['a', 'b', 'c'], 100)
+ background_dir = os.path.join(wav_dir, '_background_noise_')
+ os.mkdir(background_dir)
+ wav_data = self._getWavData()
+ for i in range(10):
+ file_path = os.path.join(background_dir, 'background_audio_%d.wav' % i)
+ self._saveTestWavFile(file_path, wav_data)
+ return wav_dir
+
+ def _getDefaultFlags(self):
+ flags = {
+ 'data_url': '',
+ 'data_dir': self._prepareDummyTrainingData(),
+ 'wanted_words': 'a,b,c',
+ 'sample_rate': 16000,
+ 'clip_duration_ms': 1000,
+ 'window_size_ms': 30,
+ 'window_stride_ms': 20,
+ 'feature_bin_count': 40,
+ 'preprocess': 'mfcc',
+ 'silence_percentage': 25,
+ 'unknown_percentage': 25,
+ 'validation_percentage': 10,
+ 'testing_percentage': 10,
+ 'summaries_dir': os.path.join(self.get_temp_dir(), 'summaries'),
+ 'train_dir': os.path.join(self.get_temp_dir(), 'train'),
+ 'time_shift_ms': 100,
+ 'how_many_training_steps': '2',
+ 'learning_rate': '0.01',
+ 'quantize': False,
+ 'model_architecture': 'conv',
+ 'check_nans': False,
+ 'start_checkpoint': '',
+ 'batch_size': 1,
+ 'background_volume': 0.25,
+ 'background_frequency': 0.8,
+ 'eval_step_interval': 1,
+ 'save_step_interval': 1,
+ }
+ return DictStruct(**flags)
+
+ @test_util.run_deprecated_v1
+ def testTrain(self):
+ train.FLAGS = self._getDefaultFlags()
+ train.main('')
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '.pbtxt')))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '_labels.txt')))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '.ckpt-1.meta')))
+
+ @test_util.run_deprecated_v1
+ def testQuantizedTrain(self):
+ train.FLAGS = self._getDefaultFlags()
+ train.FLAGS.quantize = True
+ train.FLAGS.model_architecture = 'tiny_conv'
+ train.main('')
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '.pbtxt')))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '_labels.txt')))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(train.FLAGS.train_dir,
+ train.FLAGS.model_architecture + '.ckpt-1.meta')))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 97b6d4b..cc0995f 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -7411,7 +7411,7 @@
// misisng, the `output` tensor at that position will be zeroed.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/sparse#Segmentation)
// for an explanation of segments.
//
// For example:
@@ -7741,7 +7741,7 @@
// Computes the sum along sparse segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
@@ -9215,7 +9215,7 @@
// Computes the minimum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
@@ -9229,6 +9229,15 @@
// possible value for the specific numeric type,
// `output[i] = numeric_limits<T>::max()`.
//
+// For example:
+//
+// ``` python
+// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+// tf.unsorted_segment_min(c, tf.constant([0, 1, 0]), num_segments=2)
+// # ==> [[ 1, 2, 2, 1],
+// # [5, 6, 7, 8]]
+// ```
+//
// If the given segment ID `i` is negative, then the corresponding value is
// dropped, and will not be included in the result.
//
@@ -9276,6 +9285,57 @@
return op.Output(0)
}
+// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2.
+type TensorArrayGatherV2Attr func(optionalAttr)
+
+// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value.
+// If not specified, defaults to <unknown_rank:true >
+func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr {
+ return func(m optionalAttr) {
+ m["element_shape"] = value
+ }
+}
+
+// Deprecated. Use TensorArrayGatherV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3
+func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGatherV2",
+ Input: []tf.Input{
+ handle, indices, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns the truth value of (x == y) element-wise.
+//
+// *NOTE*: `Equal` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Equal",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient of morphological 2-D dilation with respect to the input.
//
// Arguments:
@@ -12685,7 +12745,7 @@
// Computes the maximum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
@@ -12706,6 +12766,16 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
// </div>
//
+// For example:
+//
+// ``` python
+// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+// tf.unsorted_segment_max(c, tf.constant([0, 1, 0]), num_segments=2)
+// # ==> [[ 4, 3, 3, 4],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A tensor whose shape is a prefix of `data.shape`.
@@ -22887,6 +22957,34 @@
return op.Output(0)
}
+// Creates a dataset that changes the batch size.
+//
+// Creates a dataset that changes the batch size of the dataset to current batch
+// size // num_workers.
+//
+// Arguments:
+// input_dataset: A variant tensor representing the input dataset.
+// num_workers: A scalar representing the number of workers to distribute this batch across. As
+// a result of this transformation the current batch size would end up being
+// divided by this parameter.
+//
+//
+func ExperimentalRebatchDataset(scope *Scope, input_dataset tf.Output, num_workers tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ExperimentalRebatchDataset",
+ Input: []tf.Input{
+ input_dataset, num_workers,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient of the sigmoid of `x` wrt its input.
//
// Specifically, `grad = dy * y * (1 - y)`, where `y = sigmoid(x)`, and
@@ -23987,6 +24085,36 @@
return out_example_state_data, out_delta_sparse_weights, out_delta_dense_weights
}
+// Concats all tensors in the list along the 0th dimension.
+//
+// Requires that all tensors have the same shape except the first dimension.
+//
+// input_handle: The input list.
+// element_shape: The shape of the uninitialized elements in the list. If the first
+// dimension is not -1, it is assumed that all list elements have the same
+// leading dim.
+// leading_dims: The list of leading dims of uninitialized list elements. Used if
+// the leading dim of input_handle.element_shape or the element_shape input arg
+// is not already set.
+// tensor: The concated result.
+// lengths: Output tensor containing sizes of the 0th dimension of tensors in the list, used for computing the gradient.
+//
+func TensorListConcatV2(scope *Scope, input_handle tf.Output, element_shape tf.Output, leading_dims tf.Output, element_dtype tf.DataType) (tensor tf.Output, lengths tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"element_dtype": element_dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorListConcatV2",
+ Input: []tf.Input{
+ input_handle, element_shape, leading_dims,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
// MatrixTriangularSolveAttr is an optional argument to MatrixTriangularSolve.
type MatrixTriangularSolveAttr func(optionalAttr)
@@ -25865,7 +25993,7 @@
// Computes the sum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -25878,6 +26006,16 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentSum.png" alt>
// </div>
//
+// For example:
+//
+// ```
+// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+// tf.segment_sum(c, tf.constant([0, 0, 1]))
+// # ==> [[5, 5, 5, 5],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
@@ -25902,7 +26040,7 @@
// Computes the mean along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -25916,6 +26054,16 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMean.png" alt>
// </div>
//
+// For example:
+//
+// ```
+// c = tf.constant([[1.0,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+// tf.segment_mean(c, tf.constant([0, 0, 1]))
+// # ==> [[2.5, 2.5, 2.5, 2.5],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
@@ -25940,7 +26088,7 @@
// Computes the minimum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -25953,6 +26101,15 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMin.png" alt>
// </div>
//
+// For example:
+//
+// ```
+// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+// tf.segment_min(c, tf.constant([0, 0, 1]))
+// # ==> [[1, 2, 2, 1],
+// # [5, 6, 7, 8]]
+// ```
+//
// Arguments:
//
// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
@@ -25977,7 +26134,7 @@
// Computes the sum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -25996,6 +26153,14 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentSum.png" alt>
// </div>
//
+// ``` python
+// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+// tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
+// # ==> [[ 5, 5, 5, 5],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A tensor whose shape is a prefix of `data.shape`.
@@ -26021,7 +26186,7 @@
// Computes the product along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// This operator is similar to the unsorted segment sum operator found
@@ -26032,6 +26197,15 @@
// \\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
// `j...` such that `segment_ids[j...] == i`.
//
+// For example:
+//
+// ``` python
+// c = tf.constant([[1,2,3,4], [5,6,7,8], [4,3,2,1]])
+// tf.unsorted_segment_prod(c, tf.constant([0, 1, 0]), num_segments=2)
+// # ==> [[ 4, 6, 6, 4],
+// # [5, 6, 7, 8]]
+// ```
+//
// If there is no entry for a given segment ID `i`, it outputs 1.
//
// If the given segment ID `i` is negative, then the corresponding value is
@@ -26061,9 +26235,7 @@
// Computes the mean along sparse segments of a tensor.
//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
+// See `tf.sparse.segment_sum` for usage examples.
//
// Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
// dimension, selecting a subset of dimension 0, specified by `indices`.
@@ -26274,7 +26446,7 @@
// misisng, the `output` tensor at that position will be zeroed.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Arguments:
@@ -26420,9 +26592,8 @@
//
// N is the size of the segment being reduced.
//
-// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
-// for an explanation of segments.
+// See `tf.sparse.segment_sum` for usage examples.
+//
//
// Arguments:
//
@@ -26481,7 +26652,7 @@
// misisng, the `output` tensor at that position will be zeroed.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Arguments:
@@ -26825,7 +26996,7 @@
// Computes the maximum along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -26838,6 +27009,16 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentMax.png" alt>
// </div>
//
+// For example:
+//
+// ```
+// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+// tf.segment_max(c, tf.constant([0, 0, 1]))
+// # ==> [[4, 3, 3, 4],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
@@ -27758,6 +27939,30 @@
return scope.AddOperation(opspec)
}
+// Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+//
+// Arguments:
+//
+// num_shards: An integer representing the number of shards operating in parallel.
+// index: An integer representing the current worker index.
+//
+//
+func ShardDataset(scope *Scope, input_dataset tf.Output, num_shards tf.Output, index tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"output_types": output_types, "output_shapes": output_shapes}
+ opspec := tf.OpSpec{
+ Type: "ShardDataset",
+ Input: []tf.Input{
+ input_dataset, num_shards, index,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that batches and pads `batch_size` elements from the input.
//
// Arguments:
@@ -29125,6 +29330,103 @@
return op.Output(0)
}
+// Deprecated. Use TensorArrayScatterV3
+//
+// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3
+func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "TensorArrayScatterV2",
+ Input: []tf.Input{
+ handle, indices, value, flow_in,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AsStringAttr is an optional argument to AsString.
+type AsStringAttr func(optionalAttr)
+
+// AsStringPrecision sets the optional precision attribute to value.
+//
+// value: The post-decimal precision to use for floating point numbers.
+// Only used if precision > -1.
+// If not specified, defaults to -1
+func AsStringPrecision(value int64) AsStringAttr {
+ return func(m optionalAttr) {
+ m["precision"] = value
+ }
+}
+
+// AsStringScientific sets the optional scientific attribute to value.
+//
+// value: Use scientific notation for floating point numbers.
+// If not specified, defaults to false
+func AsStringScientific(value bool) AsStringAttr {
+ return func(m optionalAttr) {
+ m["scientific"] = value
+ }
+}
+
+// AsStringShortest sets the optional shortest attribute to value.
+//
+// value: Use shortest representation (either scientific or standard) for
+// floating point numbers.
+// If not specified, defaults to false
+func AsStringShortest(value bool) AsStringAttr {
+ return func(m optionalAttr) {
+ m["shortest"] = value
+ }
+}
+
+// AsStringWidth sets the optional width attribute to value.
+//
+// value: Pad pre-decimal numbers to this width.
+// Applies to both floating point and integer numbers.
+// Only used if width > -1.
+// If not specified, defaults to -1
+func AsStringWidth(value int64) AsStringAttr {
+ return func(m optionalAttr) {
+ m["width"] = value
+ }
+}
+
+// AsStringFill sets the optional fill attribute to value.
+//
+// value: The value to pad if width > -1. If empty, pads with spaces.
+// Another typical value is '0'. String cannot be longer than 1 character.
+// If not specified, defaults to ""
+func AsStringFill(value string) AsStringAttr {
+ return func(m optionalAttr) {
+ m["fill"] = value
+ }
+}
+
+// Converts each entry in the given tensor to strings. Supports many numeric
+//
+// types and boolean.
+func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AsString",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Returns a `RaggedTensor` containing the specified sequences of numbers.
//
//
@@ -29880,7 +30182,7 @@
// Computes the product along segments of a tensor.
//
// Read
-// [the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+// [the section on segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
// for an explanation of segments.
//
// Computes a tensor such that
@@ -29893,6 +30195,16 @@
// <img style="width:100%" src="https://www.tensorflow.org/images/SegmentProd.png" alt>
// </div>
//
+// For example:
+//
+// ```
+// c = tf.constant([[1,2,3,4], [4, 3, 2, 1], [5,6,7,8]])
+// tf.segment_prod(c, tf.constant([0, 0, 1]))
+// # ==> [[4, 6, 6, 4],
+// # [5, 6, 7, 8]]
+// ```
+//
+//
// Arguments:
//
// segment_ids: A 1-D tensor whose size is equal to the size of `data`'s
@@ -33808,7 +34120,8 @@
// Arguments:
// bytes: Tensor of serialized protos with shape `batch_shape`.
// message_type: Name of the proto message type to decode.
-// field_names: List of strings containing proto field names.
+// field_names: List of strings containing proto field names. An extension field can be decoded
+// by using its full name, e.g. EXT_PACKAGE.EXT_FIELD_NAME.
// output_types: List of TF types to use for the respective field in field_names.
//
// Returns Tensor of int32 with shape `[batch_shape, len(field_names)]`.
@@ -34403,57 +34716,6 @@
return op.Output(0)
}
-// Returns the truth value of (x == y) element-wise.
-//
-// *NOTE*: `Equal` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Equal(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Equal",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TensorArrayGatherV2Attr is an optional argument to TensorArrayGatherV2.
-type TensorArrayGatherV2Attr func(optionalAttr)
-
-// TensorArrayGatherV2ElementShape sets the optional element_shape attribute to value.
-// If not specified, defaults to <unknown_rank:true >
-func TensorArrayGatherV2ElementShape(value tf.Shape) TensorArrayGatherV2Attr {
- return func(m optionalAttr) {
- m["element_shape"] = value
- }
-}
-
-// Deprecated. Use TensorArrayGatherV3
-//
-// DEPRECATED at GraphDef version 26: Use TensorArrayGatherV3
-func TensorArrayGatherV2(scope *Scope, handle tf.Output, indices tf.Output, flow_in tf.Output, dtype tf.DataType, optional ...TensorArrayGatherV2Attr) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TensorArrayGatherV2",
- Input: []tf.Input{
- handle, indices, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Interleave the values from the `data` tensors into a single tensor.
//
// Builds a merged tensor such that
@@ -35583,100 +35845,3 @@
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// AsStringAttr is an optional argument to AsString.
-type AsStringAttr func(optionalAttr)
-
-// AsStringPrecision sets the optional precision attribute to value.
-//
-// value: The post-decimal precision to use for floating point numbers.
-// Only used if precision > -1.
-// If not specified, defaults to -1
-func AsStringPrecision(value int64) AsStringAttr {
- return func(m optionalAttr) {
- m["precision"] = value
- }
-}
-
-// AsStringScientific sets the optional scientific attribute to value.
-//
-// value: Use scientific notation for floating point numbers.
-// If not specified, defaults to false
-func AsStringScientific(value bool) AsStringAttr {
- return func(m optionalAttr) {
- m["scientific"] = value
- }
-}
-
-// AsStringShortest sets the optional shortest attribute to value.
-//
-// value: Use shortest representation (either scientific or standard) for
-// floating point numbers.
-// If not specified, defaults to false
-func AsStringShortest(value bool) AsStringAttr {
- return func(m optionalAttr) {
- m["shortest"] = value
- }
-}
-
-// AsStringWidth sets the optional width attribute to value.
-//
-// value: Pad pre-decimal numbers to this width.
-// Applies to both floating point and integer numbers.
-// Only used if width > -1.
-// If not specified, defaults to -1
-func AsStringWidth(value int64) AsStringAttr {
- return func(m optionalAttr) {
- m["width"] = value
- }
-}
-
-// AsStringFill sets the optional fill attribute to value.
-//
-// value: The value to pad if width > -1. If empty, pads with spaces.
-// Another typical value is '0'. String cannot be longer than 1 character.
-// If not specified, defaults to ""
-func AsStringFill(value string) AsStringAttr {
- return func(m optionalAttr) {
- m["fill"] = value
- }
-}
-
-// Converts each entry in the given tensor to strings. Supports many numeric
-//
-// types and boolean.
-func AsString(scope *Scope, input tf.Output, optional ...AsStringAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AsString",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Deprecated. Use TensorArrayScatterV3
-//
-// DEPRECATED at GraphDef version 26: Use TensorArrayScatterV3
-func TensorArrayScatterV2(scope *Scope, handle tf.Output, indices tf.Output, value tf.Output, flow_in tf.Output) (flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "TensorArrayScatterV2",
- Input: []tf.Input{
- handle, indices, value, flow_in,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
diff --git a/tensorflow/lite/BUILD b/tensorflow/lite/BUILD
index 2cc661f..86eca1d 100644
--- a/tensorflow/lite/BUILD
+++ b/tensorflow/lite/BUILD
@@ -4,9 +4,9 @@
licenses(["notice"]) # Apache 2.0
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-load("//tensorflow/lite:build_def.bzl", "tflite_copts")
-load("//tensorflow:tensorflow.bzl", "if_not_windows")
+load("//tensorflow:tensorflow.bzl", "if_not_windows", "tf_cc_test")
+load("//tensorflow/lite:build_def.bzl", "tflite_cc_shared_object", "tflite_copts")
+load("//tensorflow/lite:special_rules.bzl", "tflite_portable_test_suite")
exports_files(glob([
"testdata/*.bin",
@@ -173,26 +173,17 @@
"stderr_reporter.h",
],
copts = tflite_copts() + TFLITE_DEFAULT_COPTS,
- linkopts = [
- ] + select({
- "//tensorflow:android": [
- "-llog",
- ],
- "//conditions:default": [
- ],
- }),
deps = [
":arena_planner",
":graph_info",
":memory_planner",
+ ":minimal_logging",
":schema_fbs_version",
":simple_memory_arena",
":string",
":util",
"//tensorflow/lite/c:c_api_internal",
- "//tensorflow/lite/core/api:api",
- "//tensorflow/lite/kernels:eigen_support",
- "//tensorflow/lite/kernels:gemm_support",
+ "//tensorflow/lite/core/api",
"//tensorflow/lite/nnapi:nnapi_implementation",
"//tensorflow/lite/profiling:profiler",
"//tensorflow/lite/schema:schema_fbs",
@@ -219,6 +210,9 @@
name = "string_util_test",
size = "small",
srcs = ["string_util_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":framework",
":string_util",
@@ -233,10 +227,12 @@
name = "interpreter_test",
size = "small",
srcs = ["interpreter_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":framework",
":string_util",
- "//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/kernels:kernel_util",
@@ -252,6 +248,9 @@
name = "graph_info_test",
size = "small",
srcs = ["graph_info_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":framework",
"//tensorflow/lite/testing:util",
@@ -264,6 +263,9 @@
name = "simple_memory_arena_test",
size = "small",
srcs = ["simple_memory_arena_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":simple_memory_arena",
"//tensorflow/lite/testing:util",
@@ -284,9 +286,11 @@
"testdata/test_model.bin",
"testdata/test_model_broken.bin",
],
+ tags = [
+ "tflite_not_portable",
+ ],
deps = [
":framework",
- "//tensorflow/lite/c:c_api_internal",
"//tensorflow/lite/core/api",
"//tensorflow/lite/kernels:builtin_ops",
"//tensorflow/lite/testing:util",
@@ -323,6 +327,9 @@
name = "mutable_op_resolver_test",
size = "small",
srcs = ["mutable_op_resolver_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":framework",
"//tensorflow/lite/testing:util",
@@ -344,9 +351,77 @@
name = "util_test",
size = "small",
srcs = ["util_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
deps = [
":util",
"//tensorflow/lite/c:c_api_internal",
"@com_google_googletest//:gtest",
],
)
+
+cc_library(
+ name = "minimal_logging",
+ srcs = [
+ "minimal_logging.cc",
+ ] + select({
+ "//tensorflow:android": [
+ "minimal_logging_android.cc",
+ ],
+ "//tensorflow:ios": [
+ "minimal_logging_ios.cc",
+ ],
+ "//conditions:default": [
+ "minimal_logging_default.cc",
+ ],
+ }),
+ hdrs = ["minimal_logging.h"],
+ copts = TFLITE_DEFAULT_COPTS + tflite_copts(),
+ linkopts = select({
+ "//tensorflow:android": ["-llog"],
+ "//conditions:default": [],
+ }),
+ visibility = ["//visibility:private"],
+)
+
+cc_test(
+ name = "minimal_logging_test",
+ size = "small",
+ srcs = ["minimal_logging_test.cc"],
+ tags = [
+ "tflite_not_portable_ios", # TODO(b/117786830)
+ ],
+ deps = [
+ ":minimal_logging",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+# Shared lib target for convenience, pulls in the core runtime and builtin ops.
+# Note: This target is not yet finalized, and the exact set of exported (C/C++)
+# APIs is subject to change.
+tflite_cc_shared_object(
+ name = "libtensorflowlite.so",
+ linkopts = select({
+ "//tensorflow:darwin": [
+ "-Wl,-exported_symbols_list", # This line must be directly followed by the exported_symbols.lds file
+ "$(location //tensorflow/lite:tflite_exported_symbols.lds)",
+ "-Wl,-install_name,@rpath/libtensorflowlite.so",
+ ],
+ "//tensorflow:windows": [],
+ "//conditions:default": [
+ "-z defs",
+ "-Wl,--version-script", # This line must be directly followed by the version_script.lds file
+ "$(location //tensorflow/lite:tflite_version_script.lds)",
+ ],
+ }),
+ deps = [
+ ":framework",
+ ":tflite_exported_symbols.lds",
+ ":tflite_version_script.lds",
+ "//tensorflow/lite/kernels:builtin_ops",
+ ],
+)
+
+tflite_portable_test_suite()
diff --git a/tensorflow/lite/build_def.bzl b/tensorflow/lite/build_def.bzl
index f4c9e4e..88a8faf 100644
--- a/tensorflow/lite/build_def.bzl
+++ b/tensorflow/lite/build_def.bzl
@@ -2,6 +2,7 @@
load(
"//tensorflow:tensorflow.bzl",
+ "tf_binary_additional_srcs",
"tf_cc_shared_object",
"tf_cc_test",
)
@@ -157,7 +158,7 @@
"""
toco_cmdline = " ".join([
- "//tensorflow/lite/toco:toco",
+ "$(location //tensorflow/lite/toco:toco)",
"--input_format=TENSORFLOW_GRAPHDEF",
"--output_format=TFLITE",
("--input_file=$(location %s)" % src),
@@ -168,7 +169,7 @@
srcs = [src],
outs = [out],
cmd = toco_cmdline,
- tools = ["//tensorflow/lite/toco:toco"],
+ tools = ["//tensorflow/lite/toco:toco"] + tf_binary_additional_srcs(),
)
def tflite_to_json(name, src, out):
@@ -225,6 +226,7 @@
return [
"abs",
"add",
+ "add_n",
"arg_min_max",
"avg_pool",
"batch_to_space_nd",
diff --git a/tensorflow/lite/builtin_ops.h b/tensorflow/lite/builtin_ops.h
index b6ffb82..1915565 100644
--- a/tensorflow/lite/builtin_ops.h
+++ b/tensorflow/lite/builtin_ops.h
@@ -131,6 +131,8 @@
kTfLiteBuiltinUnique = 103,
kTfLiteBuiltinCeil = 104,
kTfLiteBuiltinReverseV2 = 105,
+ kTfLiteBuiltinAddN = 106,
+ kTfLiteBuiltinGatherNd = 107,
} TfLiteBuiltinOperator;
#ifdef __cplusplus
diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc
index 3a74b1e..72667d4 100644
--- a/tensorflow/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -727,6 +727,8 @@
case BuiltinOperator_RANGE:
case BuiltinOperator_SQUARED_DIFFERENCE:
case BuiltinOperator_REVERSE_V2:
+ case BuiltinOperator_ADD_N:
+ case BuiltinOperator_GATHER_ND:
break;
}
return kTfLiteOk;
diff --git a/tensorflow/lite/core/subgraph.cc b/tensorflow/lite/core/subgraph.cc
index ab83456..2fdafa3 100644
--- a/tensorflow/lite/core/subgraph.cc
+++ b/tensorflow/lite/core/subgraph.cc
@@ -397,6 +397,10 @@
check_cancelled_func_ = check_cancelled_func;
}
+void Subgraph::ReserveNodes(int count) {
+ nodes_and_registration_.reserve(count);
+}
+
TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
int length) {
// Making sure kOptionalTensor is not re-defined to something other than -1.
@@ -410,7 +414,9 @@
continue;
}
if (index < 0 || static_cast<size_t>(index) >= context_->tensors_size) {
- ReportError("Invalid tensor index %d in %s\n", index, label);
+ ReportError(
+ "Invalid tensor index %d in %s. The subgraph has %d tensors\n", index,
+ label, context_->tensors_size);
consistent_ = false;
return kTfLiteError;
}
diff --git a/tensorflow/lite/core/subgraph.h b/tensorflow/lite/core/subgraph.h
index 5ca2977..5db15a1 100644
--- a/tensorflow/lite/core/subgraph.h
+++ b/tensorflow/lite/core/subgraph.h
@@ -59,6 +59,11 @@
// interpreter.
TfLiteStatus SetVariables(std::vector<int> variables);
+ // Ensure the internal node storage memory allocates at least `count`
+ // spots for node. NOTE, this doesn't actually add operators. This is an
+ // efficiency optimization that is subject to change.
+ void ReserveNodes(int count);
+
// Adds a node with the given parameters and returns the index of the new
// node in `node_index` (optionally). Interpreter will take ownership of
// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
@@ -68,33 +73,48 @@
const char* init_data,
size_t init_data_size, void* builtin_data,
const TfLiteRegistration* registration,
- int* node_index);
+ int* node_index = nullptr);
// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
// The value pointed to by `first_new_tensor_index` will be set to the
// index of the first new tensor if `first_new_tensor_index` is non-null.
- TfLiteStatus AddTensors(int tensors_to_add, int* first_new_tensor_index);
+ TfLiteStatus AddTensors(int tensors_to_add,
+ int* first_new_tensor_index = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
- TfLiteStatus SetTensorParametersReadOnly(int tensor_index, TfLiteType type,
- const char* name, const size_t rank,
- const int* dims,
- TfLiteQuantization quantization,
- const char* buffer, size_t bytes,
- const Allocation* allocation);
+ inline TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantization quantization,
+ const char* buffer, size_t bytes,
+ const Allocation* allocation = nullptr) {
+ return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
+ dims.data(), quantization, buffer, bytes,
+ allocation);
+ }
+ TfLiteStatus SetTensorParametersReadOnly(
+ int tensor_index, TfLiteType type, const char* name, const size_t rank,
+ const int* dims, TfLiteQuantization quantization, const char* buffer,
+ size_t bytes, const Allocation* allocation = nullptr);
// Set description of inputs/outputs/data/fptrs for node `node_index`.
// This variant assumes an external buffer has been allocated of size
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
+ inline TfLiteStatus SetTensorParametersReadWrite(
+ int tensor_index, TfLiteType type, const char* name,
+ const std::vector<int>& dims, TfLiteQuantization quantization,
+ bool is_variable = false) {
+ return SetTensorParametersReadWrite(tensor_index, type, name, dims.size(),
+ dims.data(), quantization, is_variable);
+ }
TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
const char* name, const size_t rank,
const int* dims,
TfLiteQuantization quantization,
- bool is_variable);
+ bool is_variable = false);
// WARNING: Experimental interface, subject to change
// Overrides execution plan. This bounds checks indices sent in.
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
index d5d3194..86fe7c5 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc
@@ -51,8 +51,47 @@
} while (0)
namespace {
+
+bool IsFloat(TfLiteType type) {
+ switch (type) {
+ case kTfLiteFloat32:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool IsQuantized(TfLiteType type) {
+ switch (type) {
+ case kTfLiteUInt8:
+ case kTfLiteInt8:
+ case kTfLiteInt16:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool IsHybridOperator(const TfLiteContext* context, int builtin_code,
+ const TfLiteNode* node) {
+ switch (builtin_code) {
+ case kTfLiteBuiltinConv2d:
+ case kTfLiteBuiltinFullyConnected: {
+ const int input_id = node->inputs->data[0];
+ const int filter_id = node->inputs->data[1];
+ const TfLiteType input_type = context->tensors[input_id].type;
+ const TfLiteType filter_type = context->tensors[filter_id].type;
+ return IsFloat(input_type) && IsQuantized(filter_type);
+ }
+ default:
+ return false;
+ }
+}
+
constexpr int32_t kMinSdkVersionForNNAPI = 27;
constexpr int32_t kMinSdkVersionForNNAPI11 = 28;
+constexpr int32_t kMinSdkVersionForNNAPI12 = 29;
+
} // namespace
// RAII NN API Model Destructor for use with std::unique_ptr
@@ -147,16 +186,42 @@
std::vector<int> lite_tensor_to_ann_tensor_;
};
+class DequantizeMapping {
+ public:
+ int DequantizedAnnIndex(int ann_index, TfLiteType type) const {
+ for (const auto& element : mapping_) {
+ if (ann_index == std::get<0>(element) && type == std::get<1>(element)) {
+ return std::get<2>(element);
+ }
+ }
+ return -1;
+ }
+
+ void Add(int ann_index, TfLiteType type, int dequantized_ann_index) {
+ // This assumes it is not already mapped.
+ mapping_.emplace_back(ann_index, type, dequantized_ann_index);
+ }
+
+ private:
+ // Each tuple specifies the ANN (quantized) tensor index, the desired
+ // floating-point type and the matching ANN (dequantized) tensor index. This
+ // could use a map but instead std::vector is used to keep code size lower.
+ std::vector<std::tuple<int, TfLiteType, int>> mapping_;
+};
+
// Abstract builder for building an op in the NN API graph. This handles
// the disparity between TFLite and NN API operand types. NN API has singular
// operands for both tensors and parameters, and TFLite separates the two.
class NNAPIOpBuilder {
public:
NNAPIOpBuilder(const NnApi* nnapi, TfLiteContext* context,
- OperandMapping* tensor_mapping, ANeuralNetworksModel* nn_model)
+ OperandMapping* tensor_mapping,
+ DequantizeMapping* dequantize_mapping,
+ ANeuralNetworksModel* nn_model)
: nnapi_(nnapi),
context_(context),
operand_mapping_(tensor_mapping),
+ dequantize_mapping_(dequantize_mapping),
nn_model_(nn_model) {}
TfLiteStatus AddScalarInt32Operand(int32_t value) {
@@ -190,50 +255,129 @@
return kTfLiteOk;
}
- TfLiteStatus AddTensorInput(int tensor_index) {
- int ann_index;
- TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
- augmented_inputs_.push_back(ann_index);
- return kTfLiteOk;
+ TfLiteStatus AddTensorInput(int tensor_index, bool hybrid_op) {
+ return AddTensor(tensor_index, hybrid_op, &augmented_inputs_);
}
TfLiteStatus AddTensorOutput(int tensor_index) {
- int ann_index;
- TF_LITE_ENSURE_STATUS(AddTensor(tensor_index, &ann_index));
- augmented_outputs_.push_back(ann_index);
- return kTfLiteOk;
+ return AddTensor(tensor_index, /*hybrid_op=*/false, &augmented_outputs_);
}
TfLiteStatus AddAdditionalFloat32OutputTensor(uint32_t dimension_count) {
std::vector<uint32_t> dims(dimension_count, 0);
- ANeuralNetworksOperandType operand_type{
- .type = ANEURALNETWORKS_TENSOR_FLOAT32,
- .dimensionCount = dimension_count,
- .dimensions = dims.data()};
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_,
- nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- int ann_operand = operand_mapping_->add_new_non_tensor_operand();
- augmented_outputs_.push_back(ann_operand);
- return kTfLiteOk;
+ return AddFloat32OutputTensor(dimension_count, dims.data(), nullptr);
}
TfLiteStatus AddStateFloat32Tensor(int tensor_index,
int* ann_tensor_index_out) {
TfLiteTensor* tensor = &context_->tensors[tensor_index];
- int ann_index = operand_mapping_->add_new_non_tensor_operand();
+ return AddFloat32OutputTensor(
+ tensor->dims->size, reinterpret_cast<uint32_t*>(tensor->dims->data),
+ ann_tensor_index_out);
+ }
- ANeuralNetworksOperandType operand_type{
- ANEURALNETWORKS_TENSOR_FLOAT32,
- static_cast<uint32_t>(tensor->dims->size),
- reinterpret_cast<uint32_t*>(tensor->dims->data), tensor->params.scale,
- tensor->params.zero_point};
+ // Adds a Dequantize operator and replaces the input tensor index with the
+ // dequantized version. If the dequantized version of the operator already
+ // exists then it is not added again.
+ TfLiteStatus AddDequantize(int nn_input_index, int lite_index,
+ TfLiteType dequantized_type) {
+ const int ann_index = operand_mapping_->lite_index_to_ann(lite_index);
+ int dequantized_ann_index =
+ dequantize_mapping_->DequantizedAnnIndex(ann_index, dequantized_type);
+
+ if (dequantized_ann_index == -1) {
+ // The dequantized version does not exist yet, it has to be added: a new
+ // Dequantize operation is added, yielding a new tensor.
+ const TfLiteTensor& tensor = context_->tensors[lite_index];
+ ANeuralNetworksOperandType operand_type{
+ dequantized_type, static_cast<uint32_t>(tensor.dims->size),
+ reinterpret_cast<uint32_t*>(tensor.dims->data), 0.f, 0};
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_,
+ nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+ dequantized_ann_index = operand_mapping_->add_new_non_tensor_operand();
+
+ // Add Dequantize operation.
+ const uint32_t dequantize_input[1] = {static_cast<uint32_t>(ann_index)};
+ const uint32_t dequantize_output[1] = {
+ static_cast<uint32_t>(dequantized_ann_index)};
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_, nnapi_->ANeuralNetworksModel_addOperation(
+ nn_model_, ANEURALNETWORKS_DEQUANTIZE, 1,
+ dequantize_input, 1, dequantize_output));
+ dequantize_mapping_->Add(ann_index, dequantized_type,
+ dequantized_ann_index);
+ }
+
+ // The input for the original operation is modified so that the operation
+ // now uses the dequantized tensor as input.
+ augmented_inputs_[nn_input_index] = dequantized_ann_index;
+
+ return kTfLiteOk;
+ }
+
+ // Finish emitting the op (of type `type`) into the NN API.
+ TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) {
+ // Actually add a NN API operation
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_,
+ nnapi_->ANeuralNetworksModel_addOperation(
+ nn_model_, type, static_cast<uint32_t>(augmented_inputs_.size()),
+ augmented_inputs_.data(),
+ static_cast<uint32_t>(augmented_outputs_.size()),
+ augmented_outputs_.data()));
+ augmented_inputs_.clear();
+ augmented_outputs_.clear();
+ return kTfLiteOk;
+ }
+
+ private:
+ template <typename T>
+ TfLiteStatus AddScalarOperand(T value, int32_t nn_type) {
+ ANeuralNetworksOperandType operand_type{.type = nn_type};
RETURN_TFLITE_ERROR_IF_NN_ERROR(
context_,
nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- augmented_outputs_.push_back(ann_index);
+ const int ann_index = operand_mapping_->add_new_non_tensor_operand();
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_, nnapi_->ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_index, &value, sizeof(T)));
+ augmented_inputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
- *ann_tensor_index_out = ann_index;
+ template <typename T>
+ TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values,
+ int32_t nn_type) {
+ ANeuralNetworksOperandType operand_type{
+ .type = nn_type, .dimensionCount = 1, .dimensions = &num_values};
+
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_,
+ nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+
+ const int ann_index = operand_mapping_->add_new_non_tensor_operand();
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_, nnapi_->ANeuralNetworksModel_setOperandValue(
+ nn_model_, ann_index, values, sizeof(T) * num_values));
+ augmented_inputs_.push_back(ann_index);
+ return kTfLiteOk;
+ }
+
+ TfLiteStatus AddFloat32OutputTensor(uint32_t dimension_count,
+ const uint32_t* dimension_data,
+ int* ann_index_out) {
+ ANeuralNetworksOperandType operand_type{
+ .type = ANEURALNETWORKS_TENSOR_FLOAT32,
+ .dimensionCount = dimension_count,
+ .dimensions = dimension_data,
+ };
+ RETURN_TFLITE_ERROR_IF_NN_ERROR(
+ context_,
+ nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
+ const int ann_index = operand_mapping_->add_new_non_tensor_operand();
+ augmented_outputs_.push_back(ann_index);
+ if (ann_index_out) *ann_index_out = ann_index;
return kTfLiteOk;
}
@@ -241,10 +385,11 @@
// This returns the NN API tensor index corresponding to the created tensor.
// If another caller previously created a NN API tensor for `tensor_index`
// then the existing one is returned.
- TfLiteStatus AddTensor(int tensor_index, int* ann_tensor_index_out) {
+ TfLiteStatus AddTensor(int tensor_index, bool hybrid_op,
+ std::vector<uint32_t>* indices) {
int ann_tensor_index = operand_mapping_->lite_index_to_ann(tensor_index);
if (ann_tensor_index != -1) {
- *ann_tensor_index_out = ann_tensor_index;
+ indices->push_back(ann_tensor_index);
return kTfLiteOk;
}
// Allocate a new tensor index
@@ -255,11 +400,17 @@
float scale = 0.0f;
int32_t zeroPoint = 0;
TfLiteTensor* tensor = &context_->tensors[tensor_index];
- switch (tensor->type) {
+ TfLiteType tensor_type = tensor->type;
+ if (hybrid_op && (tensor_type == kTfLiteUInt8)) {
+ // For legacy reason, UINT8 weights in hybrid operators are actually INT8
+ // values and should be interpreted as such.
+ tensor_type = kTfLiteInt8;
+ }
+ switch (tensor_type) {
case kTfLiteNoType:
// Tensors added during initialization of Ops don't have a type yet and
// should not be registered with the NNAPI.
- *ann_tensor_index_out = -1;
+ indices->push_back(-1);
return kTfLiteOk;
case kTfLiteFloat32:
nn_type = ANEURALNETWORKS_TENSOR_FLOAT32;
@@ -273,6 +424,10 @@
scale = 1;
}
break;
+ case kTfLiteInt8:
+ nn_type = ANEURALNETWORKS_TENSOR_QUANT8_SYMM;
+ scale = tensor->params.scale;
+ break;
case kTfLiteInt32:
nn_type = ANEURALNETWORKS_TENSOR_INT32;
scale = tensor->params.scale;
@@ -298,53 +453,7 @@
nn_model_, ann_tensor_index, tensor->data.raw, tensor->bytes));
}
- *ann_tensor_index_out = ann_tensor_index;
- return kTfLiteOk;
- }
-
- // Finish emitting the op (of type `type`) into the NN API.
- TfLiteStatus FinalizeAddOperation(ANeuralNetworksOperationType type) {
- // Actually add a NN API operation
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_,
- nnapi_->ANeuralNetworksModel_addOperation(
- nn_model_, type, static_cast<uint32_t>(augmented_inputs_.size()),
- augmented_inputs_.data(),
- static_cast<uint32_t>(augmented_outputs_.size()),
- augmented_outputs_.data()));
- augmented_inputs_.clear();
- augmented_outputs_.clear();
- return kTfLiteOk;
- }
-
- private:
- template <typename T>
- TfLiteStatus AddScalarOperand(T value, int32_t nn_type) {
- ANeuralNetworksOperandType operand_type{.type = nn_type};
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_,
- nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- int ann_operand = operand_mapping_->add_new_non_tensor_operand();
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_, nnapi_->ANeuralNetworksModel_setOperandValue(
- nn_model_, ann_operand, &value, sizeof(T)));
- augmented_inputs_.push_back(ann_operand);
- return kTfLiteOk;
- }
-
- template <typename T>
- TfLiteStatus AddVectorOperand(const T* values, uint32_t num_values,
- int32_t nn_type) {
- ANeuralNetworksOperandType operand_type{
- .type = nn_type, .dimensionCount = 1, .dimensions = &num_values};
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_,
- nnapi_->ANeuralNetworksModel_addOperand(nn_model_, &operand_type));
- int ann_operand = operand_mapping_->add_new_non_tensor_operand();
- RETURN_TFLITE_ERROR_IF_NN_ERROR(
- context_, nnapi_->ANeuralNetworksModel_setOperandValue(
- nn_model_, ann_operand, values, sizeof(T) * num_values));
- augmented_inputs_.push_back(ann_operand);
+ indices->push_back(ann_tensor_index);
return kTfLiteOk;
}
@@ -355,7 +464,13 @@
TfLiteContext* const context_;
// Tracks relationship between indices.
- OperandMapping* operand_mapping_;
+ OperandMapping* const operand_mapping_;
+
+ // Keeps mapping of ANN quantized tensor and float data type to equivalent
+ // dequantized ANN tensor. For example, tensor #4 (UINT8) + FLOAT32 could map
+ // to tensor #10 (FLOAT32) because a DEQUANTIZE operator was added to convert
+ // tensor #4 to a FLOAT32 tensor.
+ DequantizeMapping* const dequantize_mapping_;
// The NNAPI model.
ANeuralNetworksModel* const nn_model_;
@@ -394,8 +509,9 @@
// Return a function that knows how to translate a node into its operands
// when called. You can use this function to see if a node is supported
// (i.e. that MappingFn is not nullptr).
- static MappingFn Map(TfLiteContext* context, int builtin_code, int version,
- int android_sdk_version, TfLiteNode* node) {
+ static MappingFn Map(const TfLiteContext* context, int builtin_code,
+ int version, int android_sdk_version,
+ const TfLiteNode* node) {
switch (builtin_code) {
case kTfLiteBuiltinAdd:
if (version == 1) {
@@ -451,6 +567,11 @@
break;
case kTfLiteBuiltinConv2d:
if (version == 1) {
+ if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+ IsHybridOperator(context, builtin_code, node)) {
+ // Hybrid operators not supported before NNAPI 1.2.
+ return nullptr;
+ }
auto builtin =
reinterpret_cast<TfLiteConvParams*>(node->builtin_data);
if (builtin->dilation_width_factor != 1 ||
@@ -488,6 +609,11 @@
break;
case kTfLiteBuiltinFullyConnected:
if (version == 1) {
+ if (android_sdk_version < kMinSdkVersionForNNAPI12 &&
+ IsHybridOperator(context, builtin_code, node)) {
+ // Hybrid operators not supported before NNAPI 1.2.
+ return nullptr;
+ }
return [](const NNAPIOpMappingArgs& mapping_args)
-> ANeuralNetworksOperationType {
auto builtin = reinterpret_cast<TfLiteFullyConnectedParams*>(
@@ -720,7 +846,9 @@
break;
case kTfLiteBuiltinSvdf:
// NNAPI only support float32 weights.
+ // Only delegate to NNAPI 1.1, as SVDF does not support rank > 1 on 1.0.
if (version == 1 && node->inputs->size == 5 &&
+ android_sdk_version >= kMinSdkVersionForNNAPI11 &&
context->tensors[node->inputs->data[/*kWeightsFeatureTensor*/ 1]]
.type == kTfLiteFloat32) {
return [](const NNAPIOpMappingArgs& mapping_args)
@@ -746,8 +874,11 @@
break;
case kTfLiteBuiltinLstm:
// NNAPI only support float32 weights.
+ // Only delegate to NNAPI 1.1, as 1.0 has a bug for optional tensors
+ // which would affect LSTM.
// TODO(miaowang): add loggings to indicate why the op is rejected.
if (version == 1 && node->inputs->size == 20 &&
+ android_sdk_version >= kMinSdkVersionForNNAPI11 &&
context->tensors[node->inputs
->data[/*kInputToOutputWeightsTensor*/ 4]]
.type == kTfLiteFloat32) {
@@ -952,18 +1083,69 @@
std::unique_ptr<NNMemory> nn_input_memory_;
std::unique_ptr<NNMemory> nn_output_memory_;
+ void AddDequantizeOperatorsWhereNeeded(const TfLiteContext* context,
+ int builtin_code,
+ const TfLiteNode* node,
+ NNAPIOpBuilder* builder) {
+ // Depending on the operator and the input data format, Dequantize
+ // operators may need to be added. For example when the input is
+ // floating-point but weights are quantized then the weights will first be
+ // dequantized to the same format as the input before being passed to the
+ // operator.
+
+ // The tensor determining whether the inputs should be floating-point.
+ int input_tensor_index = -1;
+ std::vector<int> inputs_to_potentially_dequantize;
+
+ switch (builtin_code) {
+ case kTfLiteBuiltinConv2d:
+ case kTfLiteBuiltinFullyConnected: {
+ input_tensor_index = 0;
+ // Weights and bias are inputs #1 and #2 respectively and may require
+ // dequantization.
+ inputs_to_potentially_dequantize = {1, 2};
+ break;
+ }
+ default:
+ return;
+ }
+
+ int tensor_id = node->inputs->data[input_tensor_index];
+ if (tensor_id < 0) return;
+
+ // Nothing to do if the input is not floating-point.
+ if (!IsFloat(context->tensors[tensor_id].type)) return;
+
+ for (int i : inputs_to_potentially_dequantize) {
+ tensor_id = node->inputs->data[i];
+ if (tensor_id < 0) continue; // Ignore optional input.
+
+ const TfLiteType type = context->tensors[tensor_id].type;
+ // Nothing to do for this tensor if it's not quantized.
+ if (type != kTfLiteUInt8) continue;
+
+ // Insert Dequantize operator if it hasn't been done already and change
+ // the node's input accordingly.
+ builder->AddDequantize(i, node->inputs->data[i], type);
+ }
+ }
+
TfLiteStatus AddOpsAndTensors(TfLiteContext* context) {
- // The operand builder allows creating a single op. We create it at this
- // reduced power position rather than in the for loop to avoid reallocating
- // the vectors.
- NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_, nn_model_.get());
- // Add Tensors
- // allocate outside to avoid realloc
+ DequantizeMapping dequantize_mapping;
+ // The operand builder allows creating a single op. It is created outside
+ // the for loop to avoid reallocating the vectors.
+ NNAPIOpBuilder builder(nnapi_, context, &operand_mapping_,
+ &dequantize_mapping, nn_model_.get());
+ // Add Tensors.
for (auto node_index : nodes_) {
// Obtain the op and registration.
TfLiteNode* node;
TfLiteRegistration* reg;
- context->GetNodeAndRegistration(context, node_index, &node, ®);
+ TF_LITE_ENSURE_STATUS(
+ context->GetNodeAndRegistration(context, node_index, &node, ®));
+
+ const bool hybrid_op = IsHybridOperator(context, reg->builtin_code, node);
+
// Map inputs to NN API tensor indices.
for (auto input_index : TfLiteIntArrayView(node->inputs)) {
if (input_index == kOptionalTensor &&
@@ -975,7 +1157,7 @@
// tensor when supported by NNAPI.
TF_LITE_ENSURE_STATUS(builder.AddVectorFloat32Operand(nullptr, 0));
} else {
- TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index));
+ TF_LITE_ENSURE_STATUS(builder.AddTensorInput(input_index, hybrid_op));
}
}
// Get op type and operands
@@ -988,6 +1170,11 @@
TF_LITE_ENSURE_STATUS(builder.AddTensorOutput(output_index));
}
+ // Dequantize operators may have to be added in case inputs are to be
+ // floating-point.
+ AddDequantizeOperatorsWhereNeeded(context, reg->builtin_code, node,
+ &builder);
+
builder.FinalizeAddOperation(nn_op_type);
}
return kTfLiteOk;
@@ -1021,7 +1208,7 @@
total_output_byte_size += context->tensors[i].bytes;
}
- // Add state output tensors as model inputs
+ // Add state output tensors as model outputs.
for (int i : model_state_outputs_) {
outputs.push_back(i);
}
diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
index 5da052e..3751238 100644
--- a/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
+++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate_test.cc
@@ -49,6 +49,24 @@
const std::vector<int>& dims) {
return interpreter_->ResizeInputTensor(tensor_index, dims);
}
+
+ protected:
+ void SetData(int index, TensorType type, std::initializer_list<float> data) {
+ switch (type) {
+ case TensorType_FLOAT32:
+ PopulateTensor(index, data);
+ break;
+ case TensorType_INT32:
+ QuantizeAndPopulate<int32_t>(index, data);
+ break;
+ case TensorType_UINT8:
+ QuantizeAndPopulate<uint8_t>(index, data);
+ break;
+ default:
+ FAIL() << "Type not supported: " << type;
+ break;
+ }
+ }
};
class FloatAddOpModel : public SingleOpModelWithNNAPI {
@@ -225,14 +243,15 @@
EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.5, 6.5}));
}
-class BaseConvolutionOpModel : public SingleOpModelWithNNAPI {
+class ConvolutionOpModel : public SingleOpModelWithNNAPI {
public:
- BaseConvolutionOpModel(
+ ConvolutionOpModel(
const TensorData& input, const TensorData& filter,
const TensorData& output, int stride_width = 2, int stride_height = 2,
enum Padding padding = Padding_VALID,
enum ActivationFunctionType activation = ActivationFunctionType_NONE,
- int dilation_width_factor = 1, int dilation_height_factor = 1) {
+ int dilation_width_factor = 1, int dilation_height_factor = 1)
+ : input_type_(input.type), filter_type_(filter.type) {
input_ = AddInput(input);
filter_ = AddInput(filter);
@@ -249,7 +268,8 @@
}
output_ = AddOutput(output);
- if (input.type != TensorType_FLOAT32) {
+
+ if (input_type_ != TensorType_FLOAT32) {
// The following is required by quantized inference. It is the unittest's
// responsibility to make sure the output scale falls into the correct
// range.
@@ -265,56 +285,53 @@
BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)});
}
+ void SetInput(std::initializer_list<float> data) {
+ SetData(input_, input_type_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ SetData(filter_, filter_type_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ const auto bias_type =
+ (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
+ SetData(bias_, bias_type, data);
+ }
+
+ std::vector<float> GetOutput() {
+ if (input_type_ == TensorType_FLOAT32) {
+ return ExtractVector<float>(output_);
+ } else {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+ }
+
+ std::vector<uint8_t> GetQuantizedOutput() {
+ if (input_type_ == TensorType_FLOAT32) {
+ return {}; // Not supported.
+ } else {
+ return ExtractVector<uint8_t>(output_);
+ }
+ }
+
protected:
int input_;
int filter_;
int bias_;
int output_;
-};
-class ConvolutionOpModel : public BaseConvolutionOpModel {
- public:
- using BaseConvolutionOpModel::BaseConvolutionOpModel;
-
- void SetFilter(std::initializer_list<float> f) { PopulateTensor(filter_, f); }
-
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
-
- void SetInput(std::initializer_list<float> data) {
- PopulateTensor(input_, data);
- }
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-};
-
-class QuantizedConvolutionOpModel : public BaseConvolutionOpModel {
- public:
- using BaseConvolutionOpModel::BaseConvolutionOpModel;
-
- void SetInput(std::initializer_list<float> data) {
- QuantizeAndPopulate<uint8_t>(input_, data);
- }
-
- void SetFilter(std::initializer_list<float> data) {
- QuantizeAndPopulate<uint8_t>(filter_, data);
- }
-
- void SetBias(std::initializer_list<float> data) {
- QuantizeAndPopulate<int32_t>(bias_, data);
- }
-
- std::vector<uint8_t> GetOutput() { return ExtractVector<uint8_t>(output_); }
- std::vector<float> GetDequantizedOutput() {
- return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
- GetScale(output_), GetZeroPoint(output_));
- }
+ const TensorType input_type_;
+ const TensorType filter_type_;
};
// In this tests we set the input and output scales so that the results
// match exactly the 'non-quantized' version.
-TEST(NNAPIDelegate, SimpleTestQuantized) {
- QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
- {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
- {TensorType_UINT8, {}, -127, 128});
+TEST(ConvolutionOpTest, SimpleTestQuantized) {
+ ConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64},
+ {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64},
+ {TensorType_UINT8, {}, -127, 128});
m.SetInput({
// First batch
1, 1, 1, 1, // row = 1
@@ -332,25 +349,55 @@
m.Invoke();
- EXPECT_THAT(m.GetDequantizedOutput(),
- ElementsAreArray(ArrayFloatNear(
- {
- 18, 2, 5, // first batch, left
- 18, 2, 5, // first batch, right
- 17, 4, 3, // second batch, left
- 37, 4, 3, // second batch, right
- },
- 1e-5)));
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 2, 5, // first batch, left
+ 18, 2, 5, // first batch, right
+ 17, 4, 3, // second batch, left
+ 37, 4, 3, // second batch, right
+ },
+ 1e-5)));
// For good measure, let's also verify the quantized values:
- EXPECT_THAT(m.GetOutput(), ElementsAreArray({
- 145, 129, 132, //
- 145, 129, 132, //
- 144, 131, 130, //
- 164, 131, 130, //
- }));
+ EXPECT_THAT(m.GetQuantizedOutput(), ElementsAreArray({
+ 145, 129, 132, //
+ 145, 129, 132, //
+ 144, 131, 130, //
+ 164, 131, 130, //
+ }));
}
-TEST(NNAPIDelegate, Conv2DWithNoActivation) {
+TEST(ConvolutionOpTest, FloatInputQuantizedWeights) {
+ ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
+ {TensorType_UINT8, {3, 2, 2, 1}, 0, 64},
+ {TensorType_FLOAT32, {}});
+ m.SetInput({
+ // First batch
+ 1, 1, 1, 2, // row = 1
+ 2, 2, 2, 1, // row = 2
+ // Second batch
+ 1, 2, 3, 4, // row = 1
+ 1, 2, 3, 4, // row = 2
+ });
+ m.SetFilter({
+ 1, 2, 3, 4, // first 2x2 filter
+ 0, 1, 0, 1, // second 2x2 filter
+ 0, 0, 1, 1, // third 2x2 filter
+ });
+ m.SetBias({1, 2, 3});
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ {
+ 18, 5, 7, // first batch, left
+ 16, 5, 6, // first batch, right
+ 17, 6, 6, // second batch, left
+ 37, 10, 10, // second batch, right
+ },
+ 0.2)));
+}
+
+TEST(ConvolutionOpTest, NoActivation) {
ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}},
{TensorType_FLOAT32, {3, 2, 2, 1}},
{TensorType_FLOAT32, {}});
@@ -458,56 +505,48 @@
}));
}
-class FloatFullyConnectedOpModel : public SingleOpModelWithNNAPI {
+class FullyConnectedOpModel : public SingleOpModelWithNNAPI {
public:
- FloatFullyConnectedOpModel(int units, int batches, const TensorData& input,
- const TensorData& output = {TensorType_FLOAT32})
- : batches_(batches), units_(units) {
- int total_input_size = 1;
- for (int i = 0; i < input.shape.size(); ++i) {
- total_input_size *= input.shape[i];
- }
- input_size_ = total_input_size / batches_;
-
+ FullyConnectedOpModel(
+ const TensorData& input, const TensorData& weights,
+ const TensorData& output,
+ enum ActivationFunctionType activation = ActivationFunctionType_NONE)
+ : input_type_(input.type), weights_type_(weights.type) {
input_ = AddInput(input);
- weights_ =
- AddInput({input.type, {units_, input_size_}, input.min, input.max});
+ weights_ = AddInput(weights);
+ const int units = weights.shape[0];
if (input.type == TensorType_FLOAT32) {
- bias_ = AddInput({TensorType_FLOAT32, {units_}});
+ bias_ = AddInput({TensorType_FLOAT32, {units}});
} else {
// This is a quantized version. The scale of 'bias' depends on the scales
// of input and filter. Supposedly this is correctly set during quantized
// training.
auto bias_scale = GetScale(input_) * GetScale(weights_);
- TensorData bias{TensorType_INT32, {units_}, 0, 0, bias_scale};
+ TensorData bias{TensorType_INT32, {units}, 0, 0, bias_scale};
bias_ = AddInput(bias);
}
output_ = AddOutput(output);
- SetBuiltinOp(
- BuiltinOperator_FULLY_CONNECTED, BuiltinOptions_FullyConnectedOptions,
- CreateFullyConnectedOptions(builder_, ActivationFunctionType_RELU)
- .Union());
+ SetBuiltinOp(BuiltinOperator_FULLY_CONNECTED,
+ BuiltinOptions_FullyConnectedOptions,
+ CreateFullyConnectedOptions(builder_, activation).Union());
BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
}
- int input_size() { return input_size_; }
- int num_units() { return units_; }
- int num_batches() { return batches_; }
-
- void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); }
-
- void SetWeights(std::initializer_list<float> f) {
- PopulateTensor(weights_, f);
- }
-
void SetInput(std::initializer_list<float> data) {
- PopulateTensor(input_, data);
+ SetData(input_, input_type_, data);
}
- void SetInput(int offset, float* begin, float* end) {
- PopulateTensor(input_, offset, begin, end);
+
+ void SetWeights(std::initializer_list<float> data) {
+ SetData(weights_, weights_type_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ const auto bias_type =
+ (input_type_ == TensorType_FLOAT32) ? input_type_ : TensorType_INT32;
+ SetData(bias_, bias_type, data);
}
std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
@@ -518,14 +557,14 @@
int bias_;
int output_;
- int batches_;
- int units_;
- int input_size_;
+ const TensorType input_type_;
+ const TensorType weights_type_;
};
-TEST(NNAPIDelegate, FullyConnectedSimpleTest) {
- FloatFullyConnectedOpModel m(/*units=*/3, /*batches=*/2,
- /*input=*/{TensorType_FLOAT32, {2, 10}});
+TEST(FullyConnectedOpTest, SimpleTest) {
+ FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
+ /*weights=*/{TensorType_FLOAT32, {3, 10}},
+ /*output=*/{TensorType_FLOAT32});
m.SetWeights({
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
@@ -543,6 +582,28 @@
EXPECT_THAT(m.GetOutput(), ElementsAre(24, 25, 26, 58, 59, 60));
}
+TEST(FullyConnectedOpTest, FloatInputQuantizedWeights) {
+ FullyConnectedOpModel m(/*input=*/{TensorType_FLOAT32, {2, 10}},
+ /*weights=*/{TensorType_UINT8, {3, 10}, 0, 64},
+ /*output=*/{TensorType_FLOAT32});
+ m.SetWeights({
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 0
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, // u = 1
+ });
+ m.SetBias({1, 2, 3});
+
+ m.SetInput({
+ 1, 2, 3, 4, 5, 6, 7, 8, -9, -10, // b = 0
+ 1, 2, 3, 4, 5, 6, 7, -8, 9, -10, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray(ArrayFloatNear({24, 25, 26, 58, 59, 60}, 1.3)));
+}
+
class SoftmaxOpModel : public SingleOpModelWithNNAPI {
public:
SoftmaxOpModel(int batches, int size, float beta)
diff --git a/tensorflow/lite/examples/android/app/build.gradle b/tensorflow/lite/examples/android/app/build.gradle
index b372afa..7b34525 100644
--- a/tensorflow/lite/examples/android/app/build.gradle
+++ b/tensorflow/lite/examples/android/app/build.gradle
@@ -1,5 +1,13 @@
apply plugin: 'com.android.application'
+// import DownloadModels task
+project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
+project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
+
+// Download default models; if you wish to use your own models then
+// place them in the "assets" directory and comment out this line.
+apply from: "download-models.gradle"
+
android {
compileSdkVersion 26
buildToolsVersion '27.0.3'
@@ -36,14 +44,6 @@
}
}
-// import DownloadModels task
-project.ext.ASSET_DIR = projectDir.toString() + '/src/main/assets'
-project.ext.TMP_DIR = project.buildDir.toString() + '/downloads'
-
-// Download default models; if you wish to use your own models then
-// place them in the "assets" directory and comment out this line.
-apply from: "download-models.gradle"
-
dependencies {
implementation fileTree(dir: 'libs', include: ['*.jar'])
implementation 'org.tensorflow:tensorflow-lite:0.0.0-nightly'
diff --git a/tensorflow/lite/experimental/examples/lstm/BUILD b/tensorflow/lite/experimental/examples/lstm/BUILD
index a4950d2..827b104 100644
--- a/tensorflow/lite/experimental/examples/lstm/BUILD
+++ b/tensorflow/lite/experimental/examples/lstm/BUILD
@@ -44,9 +44,9 @@
"//tensorflow:tensorflow_py",
"//tensorflow/examples/tutorials/mnist:input_data",
"//tensorflow/lite/python:lite",
- "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:framework",
"//tensorflow/python:platform",
- "//tensorflow/python/tools:optimize_for_inference",
+ "//tensorflow/python/tools:optimize_for_inference_lib",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
index 99f4bed..9dc8109 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_lstm_test.py
@@ -150,8 +150,8 @@
curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"],
[tf.float32.as_datatype_enum])
- tflite = tf.lite.toco_convert(
- curr, [tflite_input], [outputs], allow_custom_ops=False)
+ converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs])
+ tflite = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite)
diff --git a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
index d049c78..7a937ce 100644
--- a/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/bidirectional_sequence_rnn_test.py
@@ -149,8 +149,8 @@
curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"],
[tf.float32.as_datatype_enum])
- tflite = tf.lite.toco_convert(
- curr, [tflite_input], [outputs], allow_custom_ops=False)
+ converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs])
+ tflite = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite)
diff --git a/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py b/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py
index 0038e54..e6d329f 100644
--- a/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py
+++ b/tensorflow/lite/experimental/examples/lstm/tflite_lstm.py
@@ -22,17 +22,27 @@
import tensorflow as tf
from tensorflow.lite.python import lite
+from tensorflow.python.eager import context
+from tensorflow.python.framework import ops
from tensorflow.python.keras import activations
from tensorflow.python.keras import initializers
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import rnn_cell_impl
+from tensorflow.python.ops import variable_scope as vs
+from tensorflow.python.ops.rnn import _best_effort_input_batch_size
+from tensorflow.python.ops.rnn import _dynamic_rnn_loop
+from tensorflow.python.ops.rnn import _should_cache
+from tensorflow.python.ops.rnn import _transpose_batch_time
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import nest
class TFLiteLSTMCell(rnn_cell_impl.LayerRNNCell):
@@ -394,3 +404,240 @@
}
base_config = super(TFLiteLSTMCell, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
+
+
+def dynamic_rnn(cell,
+ inputs,
+ sequence_length=None,
+ initial_state=None,
+ dtype=None,
+ parallel_iterations=None,
+ swap_memory=False,
+ time_major=True,
+ scope=None):
+ """Creates a recurrent neural network specified by RNNCell `cell`.
+
+ Performs fully dynamic unrolling of `inputs`.
+
+ Example:
+
+ ```python
+ # create a BasicRNNCell
+ rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
+
+ # 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]
+
+ # defining initial state
+ initial_state = rnn_cell.zero_state(batch_size, dtype=tf.float32)
+
+ # 'state' is a tensor of shape [batch_size, cell_state_size]
+ outputs, state = tf.nn.dynamic_rnn(rnn_cell, input_data,
+ initial_state=initial_state,
+ dtype=tf.float32)
+ ```
+
+ ```python
+ # create 2 LSTMCells
+ rnn_layers = [tf.nn.rnn_cell.LSTMCell(size) for size in [128, 256]]
+
+ # create a RNN cell composed sequentially of a number of RNNCells
+ multi_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(rnn_layers)
+
+ # 'outputs' is a tensor of shape [batch_size, max_time, 256]
+ # 'state' is a N-tuple where N is the number of LSTMCells containing a
+ # tf.contrib.rnn.LSTMStateTuple for each cell
+ outputs, state = tf.nn.dynamic_rnn(cell=multi_rnn_cell,
+ inputs=data,
+ dtype=tf.float32)
+ ```
+
+
+ Args:
+ cell: An instance of RNNCell.
+ inputs: The RNN inputs.
+ If `time_major == False` (default), this must be a `Tensor` of shape:
+ `[batch_size, max_time, ...]`, or a nested tuple of such elements.
+ If `time_major == True`, this must be a `Tensor` of shape: `[max_time,
+ batch_size, ...]`, or a nested tuple of such elements. This may also be
+ a (possibly nested) tuple of Tensors satisfying this property. The
+ first two dimensions must match across all the inputs, but otherwise the
+ ranks and other shape components may differ. In this case, input to
+ `cell` at each time-step will replicate the structure of these tuples,
+ except for the time dimension (from which the time is taken). The input
+ to `cell` at each time step will be a `Tensor` or (possibly nested)
+ tuple of Tensors each with dimensions `[batch_size, ...]`.
+ sequence_length: (optional) An int32/int64 vector sized `[batch_size]`. Used
+ to copy-through state and zero-out outputs when past a batch element's
+ sequence length. So it's more for performance than correctness.
+ initial_state: (optional) An initial state for the RNN. If `cell.state_size`
+ is an integer, this must be a `Tensor` of appropriate type and shape
+ `[batch_size, cell.state_size]`. If `cell.state_size` is a tuple, this
+ should be a tuple of tensors having shapes `[batch_size, s] for s in
+ cell.state_size`.
+ dtype: (optional) The data type for the initial state and expected output.
+ Required if initial_state is not provided or RNN state has a heterogeneous
+ dtype.
+ parallel_iterations: (Default: 32). The number of iterations to run in
+ parallel. Those operations which do not have any temporal dependency and
+ can be run in parallel, will be. This parameter trades off time for
+ space. Values >> 1 use more memory but take less time, while smaller
+ values use less memory but computations take longer.
+ swap_memory: Transparently swap the tensors produced in forward inference
+ but needed for back prop from GPU to CPU. This allows training RNNs which
+ would typically not fit on a single GPU, with very minimal (or no)
+ performance penalty.
+ time_major: The shape format of the `inputs` and `outputs` Tensors. If true,
+ these `Tensors` must be shaped `[max_time, batch_size, depth]`. If false,
+ these `Tensors` must be shaped `[batch_size, max_time, depth]`. Using
+ `time_major = True` is a bit more efficient because it avoids transposes
+ at the beginning and end of the RNN calculation. However, most TensorFlow
+ data is batch-major, so by default this function accepts input and emits
+ output in batch-major form.
+ scope: VariableScope for the created subgraph; defaults to "rnn".
+
+ Returns:
+ A pair (outputs, state) where:
+
+ outputs: The RNN output `Tensor`.
+
+ If time_major == False (default), this will be a `Tensor` shaped:
+ `[batch_size, max_time, cell.output_size]`.
+
+ If time_major == True, this will be a `Tensor` shaped:
+ `[max_time, batch_size, cell.output_size]`.
+
+ Note, if `cell.output_size` is a (possibly nested) tuple of integers
+ or `TensorShape` objects, then `outputs` will be a tuple having the
+ same structure as `cell.output_size`, containing Tensors having shapes
+ corresponding to the shape data in `cell.output_size`.
+
+ state: The final state. If `cell.state_size` is an int, this
+ will be shaped `[batch_size, cell.state_size]`. If it is a
+ `TensorShape`, this will be shaped `[batch_size] + cell.state_size`.
+ If it is a (possibly nested) tuple of ints or `TensorShape`, this will
+ be a tuple having the corresponding shapes. If cells are `LSTMCells`
+ `state` will be a tuple containing a `LSTMStateTuple` for each cell.
+
+ Raises:
+ TypeError: If `cell` is not an instance of RNNCell.
+ ValueError: If inputs is None or an empty list.
+ RuntimeError: If not using control flow v2.
+ """
+
+ # Currently only support time_major == True case.
+ assert time_major
+
+ # TODO(b/123051275): We need to check if the cells are TfLiteLSTMCells or
+ # TfLiteRNNCells.
+ rnn_cell_impl.assert_like_rnncell("cell", cell)
+
+ if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
+ raise RuntimeError("OpHint dynamic rnn only supports control flow v2.")
+
+ parent_first_child_input = [{
+ "parent_ophint_input_index": 0,
+ "first_child_ophint_input_index": 0
+ }]
+ parent_last_child_output = [{
+ "parent_output_index": 0,
+ # For LstmCell, the index is 2.
+ # For RnnCell, the index is 1.
+ # So we use -1 meaning it's the last one.
+ "child_output_index": -1
+ }]
+ internal_children_input_output = [{
+ "child_input_index": 0,
+ # For LstmCell, the index is 2.
+ # For RnnCell, the index is 1.
+ # So we use -1 meaning it's the last one.
+ "child_output_index": -1
+ }]
+ inputs_outputs_mappings = {
+ "parent_first_child_input": parent_first_child_input,
+ "parent_last_child_output": parent_last_child_output,
+ "internal_children_input_output": internal_children_input_output
+ }
+ tflite_wrapper = lite.OpHint(
+ "TfLiteDynamicRnn",
+ level=2,
+ children_inputs_mappings=inputs_outputs_mappings)
+ with vs.variable_scope(scope or "rnn") as varscope:
+ # Create a new scope in which the caching device is either
+ # determined by the parent scope, or is set to place the cached
+ # Variable using the same placement as for the rest of the RNN.
+ if _should_cache():
+ if varscope.caching_device is None:
+ varscope.set_caching_device(lambda op: op.device)
+
+ inputs = tflite_wrapper.add_input(inputs, name="input", index_override=0)
+
+ # By default, time_major==False and inputs are batch-major: shaped
+ # [batch, time, depth]
+ # For internal calculations, we transpose to [time, batch, depth]
+ flat_input = nest.flatten(inputs)
+
+ if not time_major:
+ # (batch, time, depth) => (time, batch, depth)
+ flat_input = [ops.convert_to_tensor(input_) for input_ in flat_input]
+ flat_input = tuple(_transpose_batch_time(input_) for input_ in flat_input)
+
+ parallel_iterations = parallel_iterations or 32
+ if sequence_length is not None:
+ sequence_length = math_ops.to_int32(sequence_length)
+ if sequence_length.get_shape().rank not in (None, 1):
+ raise ValueError(
+ "sequence_length must be a vector of length batch_size, "
+ "but saw shape: %s" % sequence_length.get_shape())
+ sequence_length = array_ops.identity( # Just to find it in the graph.
+ sequence_length,
+ name="sequence_length")
+
+ batch_size = _best_effort_input_batch_size(flat_input)
+
+ if initial_state is not None:
+ state = initial_state
+ else:
+ if not dtype:
+ raise ValueError("If there is no initial_state, you must give a dtype.")
+ if getattr(cell, "get_initial_state", None) is not None:
+ state = cell.get_initial_state(
+ inputs=None, batch_size=batch_size, dtype=dtype)
+ else:
+ state = cell.zero_state(batch_size, dtype)
+
+ def _assert_has_shape(x, shape):
+ x_shape = array_ops.shape(x)
+ packed_shape = array_ops.stack(shape)
+ return control_flow_ops.Assert(
+ math_ops.reduce_all(math_ops.equal(x_shape, packed_shape)), [
+ "Expected shape for Tensor %s is " % x.name, packed_shape,
+ " but saw shape: ", x_shape
+ ])
+
+ if not context.executing_eagerly() and sequence_length is not None:
+ # Perform some shape validation
+ with ops.control_dependencies(
+ [_assert_has_shape(sequence_length, [batch_size])]):
+ sequence_length = array_ops.identity(
+ sequence_length, name="CheckSeqLen")
+
+ inputs = nest.pack_sequence_as(structure=inputs, flat_sequence=flat_input)
+
+ outputs, final_state = _dynamic_rnn_loop(
+ cell,
+ inputs,
+ state,
+ parallel_iterations=parallel_iterations,
+ swap_memory=swap_memory,
+ sequence_length=sequence_length,
+ dtype=dtype)
+
+ # Outputs of _dynamic_rnn_loop are always shaped [time, batch, depth].
+ # If we are performing batch-major calculations, transpose output back
+ # to shape [batch, time, depth]
+ if not time_major:
+ # (time, batch, depth) => (batch, time, depth)
+ outputs = nest.map_structure(_transpose_batch_time, outputs)
+ outputs = tflite_wrapper.add_output(outputs, name="outputs")
+
+ return outputs, final_state
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
index eeb48d1..ed02d6c 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_lstm_test.py
@@ -20,12 +20,14 @@
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.lite.experimental.examples.lstm.tflite_lstm import dynamic_rnn
from tensorflow.lite.experimental.examples.lstm.tflite_lstm import TFLiteLSTMCell
from tensorflow.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
from tensorflow.python.tools import optimize_for_inference_lib
+
# Number of steps to train model.
TRAIN_STEPS = 1
@@ -67,7 +69,7 @@
TFLiteLSTMCell(self.num_units, forget_bias=0, name="rnn4")
])
- def buildModel(self, lstm_layer, is_dynamic_rnn, is_train):
+ def buildModel(self, lstm_layer, is_dynamic_rnn):
# Weights and biases for output softmax layer.
out_weights = tf.Variable(
tf.random_normal([self.num_units, self.n_classes]))
@@ -77,16 +79,11 @@
x = tf.placeholder(
"float", [None, self.time_steps, self.n_input], name="INPUT_IMAGE")
- # For dynamic_rnn, train with dynamic_rnn and inference with static_rnn.
# x is shaped [batch_size,time_steps,num_inputs]
if is_dynamic_rnn:
- if is_train:
- lstm_input = x
- outputs, _ = tf.nn.dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
- outputs = tf.unstack(outputs, axis=1)
- else:
- lstm_input = tf.unstack(x, self.time_steps, 1)
- outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
+ lstm_input = tf.transpose(x, perm=[1, 0, 2])
+ outputs, _ = dynamic_rnn(lstm_layer, lstm_input, dtype="float32")
+ outputs = tf.unstack(outputs, axis=0)
else:
lstm_input = tf.unstack(x, self.time_steps, 1)
outputs, _ = tf.nn.static_rnn(lstm_layer, lstm_input, dtype="float32")
@@ -126,8 +123,7 @@
# Reset the graph.
tf.reset_default_graph()
- x, prediction, output_class = self.buildModel(
- lstm_layer, is_dynamic_rnn, is_train=False)
+ x, prediction, output_class = self.buildModel(lstm_layer, is_dynamic_rnn)
new_sess = tf.Session(config=CONFIG)
saver = tf.train.Saver()
@@ -157,8 +153,8 @@
curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"],
[tf.float32.as_datatype_enum])
- tflite = tf.lite.toco_convert(
- curr, [tflite_input], [outputs], allow_custom_ops=False)
+ converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs])
+ tflite = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite)
try:
@@ -179,7 +175,7 @@
sess = tf.Session(config=CONFIG)
x, prediction, output_class = self.buildModel(
- self.buildLstmLayer(), is_dynamic_rnn=False, is_train=True)
+ self.buildLstmLayer(), is_dynamic_rnn=False)
self.trainModel(x, prediction, output_class, sess)
saver = tf.train.Saver()
@@ -192,26 +188,15 @@
result = self.tfliteInvoke(frozen_graph, test_inputs, output_class)
self.assertTrue(np.allclose(expected_output, result, rtol=1e-6, atol=1e-2))
+ @test_util.enable_control_flow_v2
def testDynamicRnnMultiRnnCell(self):
sess = tf.Session(config=CONFIG)
x, prediction, output_class = self.buildModel(
- self.buildLstmLayer(), is_dynamic_rnn=True, is_train=True)
+ self.buildLstmLayer(), is_dynamic_rnn=True)
self.trainModel(x, prediction, output_class, sess)
- # Since we don't yet support OpHints for dynamic, we will load the model
- # back in as a static model. This requires the variables to have the same
- # names as if they were trained as a static. Thus, we get rid of while/rnn
- # names.
- variables_to_save = {}
- for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
- op_name = i.name
- if op_name.startswith("while/rnn/"):
- op_name = op_name.split("while/rnn/")[1]
- if op_name.endswith(":0"):
- op_name = op_name.split(":0")[0]
- variables_to_save[op_name] = i
- saver = tf.train.Saver(variables_to_save)
+ saver = tf.train.Saver()
x, prediction, output_class, new_sess = self.saveAndRestoreModel(
self.buildLstmLayer(), sess, saver, is_dynamic_rnn=True)
diff --git a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
index 6f9e2dd..af7ac4c 100644
--- a/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
+++ b/tensorflow/lite/experimental/examples/lstm/unidirectional_sequence_rnn_test.py
@@ -160,8 +160,8 @@
curr, ["INPUT_IMAGE_LITE"], ["OUTPUT_CLASS"],
[tf.float32.as_datatype_enum])
- tflite = tf.lite.toco_convert(
- curr, [tflite_input], [outputs], allow_custom_ops=False)
+ converter = tf.lite.TFLiteConverter(curr, [tflite_input], [outputs])
+ tflite = converter.convert()
interpreter = tf.lite.Interpreter(model_content=tflite)
interpreter.allocate_tensors()
diff --git a/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc b/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc
new file mode 100644
index 0000000..4d96196
--- /dev/null
+++ b/tensorflow/lite/experimental/micro/ecm3531/debug_log.cc
@@ -0,0 +1,20 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/experimental/micro/debug_log.h"
+
+#include "eta_csp_io.h"
+
+extern "C" void DebugLog(const char* s) { EtaCspIoPrintf("%s", s); }
diff --git a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc b/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc
index 3d56051..2d0deb0 100644
--- a/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc
+++ b/tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/Makefile.inc
@@ -5,6 +5,11 @@
-isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Include/ \
-I$(MAKEFILE_DIR)/downloads/CMSIS_ext/
+ GENERATED_PROJECT_INCLUDES += \
+ -isystemthird_party/cmsis/CMSIS/Core/Include/ \
+ -isystemthird_party/cmsis/CMSIS/DSP/Include/ \
+ -Ithird_party/CMSIS_ext/
+
CMSIS_PREPROCESSOR_SRCS := \
tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/hanning.cc \
tensorflow/lite/experimental/micro/examples/micro_speech/CMSIS/sin_1k.cc \
@@ -24,20 +29,26 @@
MICRO_SPEECH_HDRS += $(CMSIS_PREPROCESSOR_HDRS)
THIRD_PARTY_CC_SRCS += \
- third_party/CMSIS_ext/arm_cmplx_mag_squared_q10p6.c \
- third_party/cmsis/CMSIS/DSP/Source/BasicMathFunctions/arm_mult_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_init_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_radix4_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal2.S \
- third_party/cmsis/CMSIS/DSP/Source/CommonTables/arm_const_structs.c \
- third_party/cmsis/CMSIS/DSP/Source/CommonTables/arm_common_tables.c \
- third_party/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_mean_q15.c \
- third_party/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_max_q7.c
+ $(MAKEFILE_DIR)/downloads/CMSIS_ext/arm_cmplx_mag_squared_q10p6.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/BasicMathFunctions/arm_mult_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_init_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_rfft_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_cfft_radix4_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/TransformFunctions/arm_bitreversal2.S \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/CommonTables/arm_const_structs.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/CommonTables/arm_common_tables.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_mean_q15.c \
+ $(MAKEFILE_DIR)/downloads/cmsis/CMSIS/DSP/Source/StatisticsFunctions/arm_max_q7.c
THIRD_PARTY_CC_HDRS += \
+ third_party/cmsis/CMSIS/Core/Include/cmsis_compiler.h \
+ third_party/cmsis/CMSIS/Core/Include/cmsis_gcc.h \
+ third_party/cmsis/CMSIS/Core/Include/cmsis_version.h \
+ third_party/cmsis/CMSIS/Core/Include/core_cm3.h \
third_party/cmsis/CMSIS/DSP/Include/arm_common_tables.h \
- third_party/cmsis/CMSIS/DSP/Include/arm_const_structs.h
+ third_party/cmsis/CMSIS/DSP/Include/arm_const_structs.h \
+ third_party/cmsis/CMSIS/DSP/Include/arm_math.h
endif
diff --git a/tensorflow/lite/experimental/micro/tools/make/Makefile b/tensorflow/lite/experimental/micro/tools/make/Makefile
index 9c62d8a..3457faf 100644
--- a/tensorflow/lite/experimental/micro/tools/make/Makefile
+++ b/tensorflow/lite/experimental/micro/tools/make/Makefile
@@ -30,13 +30,22 @@
# STM32F746NG board, using the CMSIS library's implementations where possible.
ALL_TAGS := $(TAGS) $(TARGET)
+# This is obviously horrible. We need to generate these 3 versions of the
+# include directories from one source.
INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/flatbuffers/include
-# These are the include paths added to any generated project file.
+# Same list of paths, but now relative to the generated project files.
+GENERATED_PROJECT_INCLUDES := \
+-I. \
+-I./third_party/gemmlowp \
+-I./third_party/flatbuffers/include
+
+# Same list of paths, but now in the format the generate_keil_project.py
+# script expects them.
PROJECT_INCLUDES := \
. \
third_party/gemmlowp \
@@ -159,6 +168,9 @@
MICROLITE_LIB_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(MICROLITE_CC_SRCS))))
+MICROLITE_LIB_OBJS += $(addprefix $(OBJDIR), \
+$(patsubst %.S,%.o,$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(THIRD_PARTY_CC_SRCS)))))
+
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.cc
@mkdir -p $(dir $@)
diff --git a/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc b/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
index 4d2465b..28f7618 100644
--- a/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
+++ b/tensorflow/lite/experimental/micro/tools/make/helper_functions.inc
@@ -52,6 +52,7 @@
# 5 - List of C/C++ header files needed to build the target.
# 6 - Linker flags required.
# 7 - C++ compilation flags needed.
+# 8 - C compilation flags needed.
# Calling eval on the output will create a <Name>_makefile target that you
# can invoke to create the standalone project.
define generate_project
@@ -68,7 +69,8 @@
sed -E 's#\%\{SRCS\}\%#$(4)#g' $$< | \
sed -E 's#\%\{EXECUTABLE\}\%#$(3)#g' | \
sed -E 's#\%\{LINKER_FLAGS\}\%#$(6)#g' | \
- sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' > $$@
+ sed -E 's#\%\{CXX_FLAGS\}\%#$(7)#g' | \
+ sed -E 's#\%\{CC_FLAGS\}\%#$(8)#g' > $$@
$(PRJDIR)$(3)/$(1)/keil_project.uvprojx: tensorflow/lite/experimental/micro/tools/make/templates/keil_project.uvprojx.tpl
@mkdir -p $$(dir $$@)
@@ -89,9 +91,9 @@
# Calling eval on the output will create targets that you can invoke to
# generate the standalone project.
define generate_microlite_projects
-$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS))
-$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS))
-$(call generate_project,keil,$(KEIL_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS))
+$(call generate_project,make,$(MAKE_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(LDFLAGS) $(MICROLITE_LIBS),$(CXXFLAGS) $(GENERATED_PROJECT_INCLUDES), $(CCFLAGS) $(GENERATED_PROJECT_INCLUDES))
+$(call generate_project,mbed,$(MBED_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS))
+$(call generate_project,keil,$(KEIL_PROJECT_FILES),$(1),$(MICROLITE_CC_SRCS) $(THIRD_PARTY_CC_SRCS) $(2),$(MICROLITE_CC_HDRS) $(THIRD_PARTY_CC_HDRS) $(MICROLITE_TEST_HDRS) $(3),$(MICROLITE_LIBS),$(CXXFLAGS),$(CCFLAGS))
endef
diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c
index 2764f3ba..25d3e7c 100644
--- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c
+++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531/_main.c
@@ -51,12 +51,6 @@
//*****************************************************************************
extern int main(int argc, char** argv);
-void DebugLog(const char* s) { EtaCspIoPrintf("%s", s); }
-void DebugLogInt32(int32_t i) { EtaCspIoPrintf("%d", i); }
-void DebugLogUInt32(uint32_t i) { EtaCspIoPrintf("%d", i); }
-void DebugLogHex(uint32_t i) { EtaCspIoPrintf("0x%8x", i); }
-void DebugLogFloat(float i) { EtaCspIoPrintf("%f", i); }
-
int _main(void) {
uint64_t time_ms;
diff --git a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc
index baae58f..4ce2f69 100644
--- a/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc
+++ b/tensorflow/lite/experimental/micro/tools/make/targets/ecm3531_makefile.inc
@@ -14,6 +14,7 @@
endif
PLATFORM_FLAGS = \
+ -DARM_MATH_CM3 \
-DFIRMWARE_BUILD \
-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
-DTF_LITE_STATIC_MEMORY \
@@ -59,13 +60,12 @@
-fno-exceptions \
-nostdlib --specs=nano.specs -t -lstdc++ -lc -lnosys -lm \
-Wl,-T,$(MAKEFILE_DIR)/targets/ecm3531/ecm3531.lds \
- -Wl,-Map=$(MAKEFILE_DIR)/gen/$(TARGET).map,--cref
+ -Wl,-Map=$(MAKEFILE_DIR)/targets/ecm3531/ecm3531.map,--cref
BUILD_TYPE := micro
MICROLITE_LIBS := \
$(GCC_ARM)/lib/gcc/arm-none-eabi/7.3.1/thumb/v7e-m/fpv4-sp/softfp/crtbegin.o \
-lm
- INCLUDES += \
- -isystem$(MAKEFILE_DIR)/downloads/cmsis/CMSIS/Core/Include/ \
+ ECM3531_INCLUDES := \
-I$(GCC_ARM)/arm-none-eabi/include/ \
-I$(ETA_SDK)/ecm3531/boards/eta_evb/projects/m3/common/inc/ \
-I$(ETA_SDK)/ecm3531/m3/reg/inc/ \
@@ -75,6 +75,9 @@
-I$(ETA_SDK)/../utils/inc/ \
-I$(ETA_SDK)/ecm3531/boards/eta_evb/eta_bsp/inc
+ INCLUDES += $(ECM3531_INCLUDES)
+ GENERATED_PROJECT_INCLUDES += $(ECM3531_INCLUDES)
+
# _main.c contains application and target specific initialization, like
# setting clock speed, default uart setups, etc. and an implementation
# of the DebugLog interfaces.
@@ -83,7 +86,13 @@
$(MAKEFILE_DIR)/targets/ecm3531/_main.c \
$(wildcard $(ETA_SDK)/ecm3531/boards/eta_evb/projects/m3/common/src/*.c) \
$(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.c) \
- $(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.s) \
+ $(wildcard $(ETA_SDK)/ecm3531/m3/csp/src/*.s)
+
+ # The linker script isn't a header, but it needs to get copied to the gen/
+ # directory for generated projects. This is similar to the behavior needed
+ # for headers.
+ MICROLITE_CC_HDRS += \
+ $(MAKEFILE_DIR)/targets/ecm3531/ecm3531.lds
TEST_SCRIPT := tensorflow/lite/experimental/micro/testing/test_ecm3531_binary.sh
# These are tests that don't currently work on the blue pill.
diff --git a/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl b/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl
index 74d54f1..ca6519c 100644
--- a/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl
+++ b/tensorflow/lite/experimental/micro/tools/make/templates/Makefile.tpl
@@ -4,12 +4,8 @@
OBJS := \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(SRCS)))
-INCLUDES := \
--I. \
--I./third_party/gemmlowp \
--I./third_party/flatbuffers/include
-
CXXFLAGS += %{CXX_FLAGS}%
+CCFLAGS += %{CC_FLAGS}%
LDFLAGS += %{LINKER_FLAGS}%
@@ -20,7 +16,6 @@
$(CC) $(CCFLAGS) $(INCLUDES) -c $< -o $@
%{EXECUTABLE}% : $(OBJS)
- $(CXX) $(LDFLAGS) $(OBJS) \
- -o $@
+ $(CXX) $(CXXFLAGS) -o $@ $(OBJS) $(LDFLAGS)
all: %{EXECUTABLE}%
diff --git a/tensorflow/lite/g3doc/demo_ios.md b/tensorflow/lite/g3doc/demo_ios.md
index f4b481d..33e74f1 100644
--- a/tensorflow/lite/g3doc/demo_ios.md
+++ b/tensorflow/lite/g3doc/demo_ios.md
@@ -1,9 +1,10 @@
-
# iOS Demo App
-The TensorFlow Lite demo is a camera app that continuously classifies whatever
-it sees from your device's back camera, using a quantized MobileNet model. These
-instructions walk you through building and running the demo on an iOS device.
+This tutorial provides a simple iOS mobile application to classify images using
+the iOS device camera. In this tutorial, you will download the demo application
+from the Tensorflow repository, build it on your computer, and install it on
+your iOS Device. You will also learn how to customize the application to suit
+your requirements.
## Prerequisites
@@ -30,47 +31,199 @@
If this is a new install, you will need to run the Xcode application once to
agree to the license before continuing.
-## Building the iOS Demo App
-
-1. Install CocoaPods if you don't have it:
+* Install CocoaPods if you don't have it:
sudo gem install cocoapods
-2. Download the model files used by the demo app (this is done from inside the
- cloned directory):
+### Step 1. Clone the TensorFlow source code
- sh tensorflow/lite/examples/ios/download_models.sh
+First, we clone the GitHub repository on the computer in a folder to get the
+demo application.
-3. Install the pod to generate the workspace file:
+```
+git clone https://github.com/tensorflow/tensorflow
+```
- cd tensorflow/lite/examples/ios/camera
- pod install
+### Step 2. Download required dependencies
- If you have installed this pod before and that command doesn't work, try
+Execute the shell script to download the model files used by the demo app (this
+is done from inside the cloned directory):
- pod repo update
+```
+ tensorflow/lite/examples/ios/download_models.sh
+```
- At the end of this step you should have a file called
- `tflite_camera_example.xcworkspace`.
+Run the following command to install TensorFlow Lite pod:
-4. Open the project in Xcode by typing this on the command line:
+```
+ cd tensorflow/lite/examples/ios/camera
+ pod install
+```
- open tflite_camera_example.xcworkspace
+If you have installed this pod before and that command doesn't work, try
- This launches Xcode if it isn't open already and opens the
- `tflite_camera_example` project.
+```
+ pod repo update
+```
-5. Under `Project navigator -> tflite_camera_example -> Targets ->
- tflite_camera_example -> General` change the bundle identifier by
- pre-pending your name:
+### Step 3. Build the XCode project
- 
+Open the `tflite_camera_example.xcworkspace` project file generated in the last
+step:
-6. Build and run the app in Xcode.
+```
+ open tflite_camera_example.xcworkspace
+```
- Note that as mentioned earlier, you must already have a device set up and
- linked to your Apple Developer account in order to deploy the app on a
- device.
+Under `Project navigator -> tflite_camera_example -> Targets ->
+tflite_camera_example -> General` change the bundle identifier by pre-pending
+your name:
+
+
+
+Plug in your iOS device. Note the app must be executed with a real device with
+camera. Select the iOS device from the drop-down menu.
+
+
+
+Click the "Run" button to build and run the app
+
+
+
+Note that as mentioned earlier, you must already have a device set up and linked
+to your Apple Developer account in order to deploy the app on a device.
You'll have to grant permissions for the app to use the device's camera. Point
the camera at various objects and enjoy seeing how the model classifies things!
+
+## Understanding iOS App Code
+
+### Get camera input
+
+The main logic of this app is in the Objective C++ source file
+`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`.
+
+The `setupAVCapture` method constructs a `AVCaptureSession` and set itself as a
+delegate. The `captureOutput:didOutputSampleBuffer:fromConnection:` method is
+called for every captured frame. It calls `runModelOnFrame` to run the model for
+every frame.
+
+### Create an interpreter
+
+To create the interpreter, we need to load the model file. The following code
+will load a model and create an interpreter.
+
+```
+model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]);
+```
+
+Behind the scenes, the model is loaded as a memory-mapped file. It offers faster
+load times and reduce the dirty pages in memory.
+
+Construct a `BuiltinOpResolver` to use the TensorFliw Lite buildin ops. Then,
+create the interpreter object using `InterpreterBuilder` that takes the model
+file as argument as shown below.
+
+```
+tflite::ops::builtin::BuiltinOpResolver resolver;
+tflite::InterpreterBuilder(*model, resolver)(&interpreter);
+```
+
+### Obtain the input buffer
+
+By default, the app uses quantized model since it's smaller and faster. The
+buffer is a raw pointer to an array of 8 bit unsigned integers (`uint8_t`). The
+following code obtains the input buffer from the interpreter:
+
+```
+// Get the index of first input tensor.
+int input_tensor_index = interpreter->inputs()[0];
+// Get the pointer to the input buffer.
+uint8_t* buffer = interpreter->typed_tensor<uint8_t>(input_tensor_index);
+```
+
+Throughout this document, it's assumed a quantized model is used.
+
+### Pre-process of bitmap image
+
+The MobileNet model we're using takes 224x224x3 inputs, where the dimensions are
+width, height, and colors (RGB). The images returned from `AVCaptureSession` is
+bigger, and has 4 color channels (RGBA).
+
+Many image classification models (like MobileNet) take fixe-sized inputs. It's
+required to scale or crop the image before feeding it into the model, and change
+the channels from RGBA to RGB.
+
+The code to pre-process the images is in `ProcessInputWithQuantizedModel`
+function in
+`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`. It's a
+simple implementation for nearest neighbor color sampling, and it only copies
+the first 3 bytes for each pixel.
+
+```
+void ProcessInputWithQuantizedModel(
+ uint8_t* input, uint8_t* output, int image_width, int image_height, int image_channels) {
+ for (int y = 0; y < wanted_input_height; ++y) {
+ uint8_t* out_row = output + (y * wanted_input_width * wanted_input_channels);
+ for (int x = 0; x < wanted_input_width; ++x) {
+ const int in_x = (y * image_width) / wanted_input_width;
+ const int in_y = (x * image_height) / wanted_input_height;
+ uint8_t* in_pixel = input + (in_y * image_width * image_channels) + (in_x * image_channels);
+ uint8_t* out_pixel = out_row + (x * wanted_input_channels);
+ for (int c = 0; c < wanted_input_channels; ++c) {
+ out_pixel[c] = in_pixel[c];
+ }
+ }
+ }
+}
+```
+
+Note the code is preprocessing and preparing the model input from the camera
+data. Therefore the first parameter `input` should be the camera buffer. The
+second parameter `output` should be the buffer of model input.
+
+### Run inference and obtain output buffer
+
+After preprocessing and filling the data into the input buffer of the
+interpreter, it's really easy to run the interpreter:
+
+```
+if (interpreter->Invoke() != kTfLiteOk) {
+ NSLog("Failed to invoke!");
+}
+```
+
+The result is stored in the output tensor buffer of the interpreter. The
+following code obtains the pointer to the buffer:
+
+```
+// Get the index of first output tensor.
+const int output_tensor_index = interpreter->outputs()[0];
+// Get the pointer to the output buffer.
+uint8_t* buffer = interpreter->typed_tensor<uint8_t>(output_tensor_index);
+```
+
+### Post-process values
+
+The output buffer contains an array of `uint8_t`, and the value range is 0-255.
+We need to convert the value to float to get the probabilities with value range
+0.0-1.0. The formula of the quantization value mapping is:
+
+ float_value = (quantized_value - zero_point) * scale
+
+The following code converts quantized values back to float values, using the
+quantizaiton parameters in tensors:
+
+```
+uint8_t* quantized_output = interpreter->typed_output_tensor<uint8_t>(0);
+int32_t zero_point = input_tensor->params.zero_point;
+float scale = input_tensor->params.scale;
+float output[output_size];
+for (int i = 0; i < output_size; ++i) {
+ output[i] = (quantized_output[i] - zero_point) * scale;
+}
+```
+
+Finally, we find the best set of classifications by storing them in a priority
+queue based on their confidence scores. See the `GetTopN` function in
+`tensorflow/lite/examples/ios/camera/CameraExampleViewController.mm`.
diff --git a/tensorflow/lite/g3doc/images/ios/build_and_execute.png b/tensorflow/lite/g3doc/images/ios/build_and_execute.png
new file mode 100644
index 0000000..a305350
--- /dev/null
+++ b/tensorflow/lite/g3doc/images/ios/build_and_execute.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/images/ios/device_selection.png b/tensorflow/lite/g3doc/images/ios/device_selection.png
new file mode 100644
index 0000000..1565fa0
--- /dev/null
+++ b/tensorflow/lite/g3doc/images/ios/device_selection.png
Binary files differ
diff --git a/tensorflow/lite/g3doc/tf_ops_compatibility.md b/tensorflow/lite/g3doc/tf_ops_compatibility.md
index cff4afc..d7c71df 100644
--- a/tensorflow/lite/g3doc/tf_ops_compatibility.md
+++ b/tensorflow/lite/g3doc/tf_ops_compatibility.md
@@ -165,6 +165,17 @@
}
```
+**ADD_N**
+
+```
+Inputs {
+ 0-N: any number of tensors (must have same size and shape)
+}
+Outputs {
+ 0: elementwise sum of the input tensors
+}
+```
+
**ARG_MAX**
```
diff --git a/tensorflow/lite/interpreter.cc b/tensorflow/lite/interpreter.cc
index 7a6a074..c840c9a 100644
--- a/tensorflow/lite/interpreter.cc
+++ b/tensorflow/lite/interpreter.cc
@@ -91,7 +91,7 @@
}
void Interpreter::ReserveNodes(int count) {
- primary_subgraph().nodes_and_registration().reserve(count);
+ primary_subgraph().ReserveNodes(count);
}
void Interpreter::AddSubgraphs(int subgraphs_to_add,
@@ -223,7 +223,12 @@
}
TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) {
- return primary_subgraph().ModifyGraphWithDelegate(delegate);
+ // TODO(ycling): It seems Flex delegate doesn't work on non-primary subgraphs.
+ // Need to investigate.
+ for (auto& subgraph : subgraphs_) {
+ TF_LITE_ENSURE_OK(context_, subgraph->ModifyGraphWithDelegate(delegate));
+ }
+ return kTfLiteOk;
}
TfLiteStatus Interpreter::SetBufferHandle(int tensor_index,
diff --git a/tensorflow/lite/kernels/BUILD b/tensorflow/lite/kernels/BUILD
index d338bc5..795bd80 100644
--- a/tensorflow/lite/kernels/BUILD
+++ b/tensorflow/lite/kernels/BUILD
@@ -152,6 +152,7 @@
srcs = [
"activations.cc",
"add.cc",
+ "add_n.cc",
"arg_min_max.cc",
"audio_spectrogram.cc",
"basic_rnn.cc",
@@ -180,8 +181,8 @@
"fully_connected.cc",
"gather.cc",
"hashtable_lookup.cc",
+ "if.cc",
"l2norm.cc",
- "layer_norm_lstm.cc",
"local_response_norm.cc",
"logical.cc",
"lsh_projection.cc",
@@ -198,7 +199,6 @@
"pow.cc",
"range.cc",
"reduce.cc",
- "relu1.cc",
"reshape.cc",
"resize_bilinear.cc",
"resize_nearest_neighbor.cc",
@@ -333,19 +333,6 @@
)
tf_cc_test(
- name = "relu1_test",
- size = "small",
- srcs = ["relu1_test.cc"],
- deps = [
- ":builtin_ops",
- "//tensorflow/lite:framework",
- "//tensorflow/lite/kernels:test_util",
- "@com_google_googletest//:gtest",
- "@flatbuffers",
- ],
-)
-
-tf_cc_test(
name = "activations_test",
size = "small",
srcs = ["activations_test.cc"],
@@ -370,6 +357,18 @@
)
tf_cc_test(
+ name = "add_n_test",
+ size = "small",
+ srcs = ["add_n_test.cc"],
+ deps = [
+ ":builtin_ops",
+ ":test_util",
+ "//tensorflow/lite:framework",
+ "@com_google_googletest//:gtest",
+ ],
+)
+
+tf_cc_test(
name = "arg_min_max_test",
size = "small",
srcs = ["arg_min_max_test.cc"],
@@ -879,19 +878,6 @@
)
tf_cc_test(
- name = "layer_norm_lstm_test",
- size = "small",
- srcs = ["layer_norm_lstm_test.cc"],
- deps = [
- ":builtin_ops",
- "//tensorflow/lite:framework",
- "//tensorflow/lite/kernels:test_util",
- "@com_google_googletest//:gtest",
- "@flatbuffers",
- ],
-)
-
-tf_cc_test(
name = "lstm_test",
size = "small",
srcs = ["lstm_test.cc"],
@@ -1217,6 +1203,7 @@
srcs = ["squared_difference_test.cc"],
deps = [
":builtin_ops",
+ "//tensorflow/lite:builtin_op_data",
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:test_util",
"@com_google_googletest//:gtest",
@@ -1224,6 +1211,23 @@
)
tf_cc_test(
+ name = "if_test",
+ size = "small",
+ srcs = ["if_test.cc"],
+ tags = ["tflite_not_portable_ios"],
+ deps = [
+ ":builtin_ops",
+ ":kernel_util",
+ ":subgraph_test_util",
+ ":test_util",
+ "//tensorflow/lite:builtin_op_data",
+ "//tensorflow/lite:framework",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
name = "fill_test",
size = "small",
srcs = ["fill_test.cc"],
@@ -1282,3 +1286,31 @@
"@com_google_googletest//:gtest_main",
],
)
+
+cc_library(
+ name = "subgraph_test_util",
+ testonly = 1,
+ srcs = ["subgraph_test_util.cc"],
+ hdrs = ["subgraph_test_util.h"],
+ deps = [
+ ":builtin_ops",
+ ":kernel_util",
+ ":test_util",
+ "//tensorflow/lite:builtin_op_data",
+ "//tensorflow/lite:framework",
+ "@com_google_googletest//:gtest",
+ "@flatbuffers",
+ ],
+)
+
+tf_cc_test(
+ name = "subgraph_test_util_test",
+ size = "small",
+ srcs = ["subgraph_test_util_test.cc"],
+ deps = [
+ ":subgraph_test_util",
+ "//tensorflow/lite:framework",
+ "//tensorflow/lite/kernels:test_util",
+ "@com_google_googletest//:gtest",
+ ],
+)
diff --git a/tensorflow/lite/kernels/activations.cc b/tensorflow/lite/kernels/activations.cc
index 4463a6c..2b35cc4 100644
--- a/tensorflow/lite/kernels/activations.cc
+++ b/tensorflow/lite/kernels/activations.cc
@@ -60,9 +60,9 @@
TfLiteStatus CheckOutputQuantParams(TfLiteContext* context,
const TfLiteTensor* input,
const TfLiteTensor* output) {
+ TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
if (input->type == kTfLiteUInt8) {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
- TF_LITE_ENSURE(context, output->params.scale == 1. / 256);
} else {
TF_LITE_ENSURE_EQ(context, output->params.zero_point, -128);
}
diff --git a/tensorflow/lite/kernels/add_n.cc b/tensorflow/lite/kernels/add_n.cc
new file mode 100644
index 0000000..3e9b2ea
--- /dev/null
+++ b/tensorflow/lite/kernels/add_n.cc
@@ -0,0 +1,88 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/lite/kernels/internal/tensor.h"
+#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace add_n {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ int num_inputs = NumInputs(node);
+ TF_LITE_ENSURE(context, num_inputs >= 2);
+ TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ output->type = input1->type;
+
+ // Check that all input tensors have the same shape and type.
+ for (int i = kInputTensor1 + 1; i < num_inputs; ++i) {
+ const TfLiteTensor* input = GetInput(context, node, i);
+ TF_LITE_ENSURE(context, HaveSameShapes(input1, input));
+ TF_LITE_ENSURE_EQ(context, input1->type, input->type);
+ }
+
+ // Use the first input node's dimension to be the dimension of the output
+ // node.
+ TfLiteIntArray* input1_dims = input1->dims;
+ TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input1_dims);
+ return context->ResizeTensor(context, output, output_dims);
+}
+
+template <typename T>
+void EvalAddN(TfLiteContext* context, TfLiteNode* node) {
+ // TODO(haoliang): Initialize all_inputs only once during init.
+ VectorOfTensors<T> all_inputs(*context, *node->inputs);
+ TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ int num_inputs = NumInputs(node);
+ const TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+ reference_ops::AddN<T>(GetTensorShape(input1), num_inputs, all_inputs.data(),
+ GetTensorData<T>(output));
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+ if (output->type == kTfLiteFloat32) {
+ EvalAddN<float>(context, node);
+ } else if (output->type == kTfLiteInt32) {
+ EvalAddN<int32_t>(context, node);
+ } else {
+ context->ReportError(context,
+ "AddN only supports FLOAT32|INT32 now, got %s.",
+ TfLiteTypeGetName(output->type));
+ return kTfLiteError;
+ }
+ return kTfLiteOk;
+}
+
+} // namespace add_n
+
+TfLiteRegistration* Register_ADD_N() {
+ static TfLiteRegistration r = {/*init*/ nullptr, /*free*/ nullptr,
+ add_n::Prepare, add_n::Eval};
+ return &r;
+}
+
+} // namespace builtin
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/add_n_test.cc b/tensorflow/lite/kernels/add_n_test.cc
new file mode 100644
index 0000000..ee9477d
--- /dev/null
+++ b/tensorflow/lite/kernels/add_n_test.cc
@@ -0,0 +1,98 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <vector>
+
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class BaseAddNOpModel : public SingleOpModel {
+ public:
+ BaseAddNOpModel(const std::vector<TensorData>& inputs,
+ const TensorData& output) {
+ int num_inputs = inputs.size();
+ std::vector<std::vector<int>> input_shapes;
+
+ for (int i = 0; i < num_inputs; ++i) {
+ inputs_.push_back(AddInput(inputs[i]));
+ input_shapes.push_back(GetShape(inputs_[i]));
+ }
+
+ output_ = AddOutput(output);
+ SetBuiltinOp(BuiltinOperator_ADD_N, BuiltinOptions_AddNOptions,
+ CreateAddNOptions(builder_).Union());
+ BuildInterpreter(input_shapes);
+ }
+
+ int input(int i) { return inputs_[i]; }
+
+ protected:
+ std::vector<int> inputs_;
+ int output_;
+};
+
+class FloatAddNOpModel : public BaseAddNOpModel {
+ public:
+ using BaseAddNOpModel::BaseAddNOpModel;
+
+ std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+};
+
+class IntegerAddNOpModel : public BaseAddNOpModel {
+ public:
+ using BaseAddNOpModel::BaseAddNOpModel;
+
+ std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
+};
+
+TEST(FloatAddNOpModel, AddMultipleTensors) {
+ FloatAddNOpModel m({{TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}},
+ {TensorType_FLOAT32, {1, 2, 2, 1}}},
+ {TensorType_FLOAT32, {}});
+ m.PopulateTensor<float>(m.input(0), {-2.0, 0.2, 0.7, 0.8});
+ m.PopulateTensor<float>(m.input(1), {0.1, 0.2, 0.3, 0.5});
+ m.PopulateTensor<float>(m.input(2), {0.5, 0.1, 0.1, 0.2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.4, 0.5, 1.1, 1.5}));
+}
+
+TEST(IntegerAddNOpModel, AddMultipleTensors) {
+ IntegerAddNOpModel m({{TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}},
+ {TensorType_INT32, {1, 2, 2, 1}}},
+ {TensorType_INT32, {}});
+ m.PopulateTensor<int32_t>(m.input(0), {-20, 2, 7, 8});
+ m.PopulateTensor<int32_t>(m.input(1), {1, 2, 3, 5});
+ m.PopulateTensor<int32_t>(m.input(2), {10, -5, 1, -2});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput(), ElementsAreArray({-9, -1, 11, 11}));
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/kernels/arg_min_max.cc b/tensorflow/lite/kernels/arg_min_max.cc
index f9adf6b..e5223ba 100644
--- a/tensorflow/lite/kernels/arg_min_max.cc
+++ b/tensorflow/lite/kernels/arg_min_max.cc
@@ -80,13 +80,14 @@
switch (input->type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
+ case kTfLiteInt8:
case kTfLiteInt32:
break;
default:
context->ReportError(
context,
- "Unkonwn input type: %d, only float32 and int types are supported",
+ "Unknown input type: %d, only float32 and int types are supported",
input->type);
return kTfLiteError;
}
@@ -135,6 +136,9 @@
case kTfLiteUInt8:
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int32_t);
break;
+ case kTfLiteInt8:
+ TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int32_t);
+ break;
case kTfLiteInt32:
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int32_t);
break;
@@ -150,6 +154,9 @@
case kTfLiteUInt8:
TF_LITE_ARG_MIN_MAX(uint8_t, int32_t, int64_t);
break;
+ case kTfLiteInt8:
+ TF_LITE_ARG_MIN_MAX(int8_t, int32_t, int64_t);
+ break;
case kTfLiteInt32:
TF_LITE_ARG_MIN_MAX(int32_t, int32_t, int64_t);
break;
diff --git a/tensorflow/lite/kernels/arg_min_max_test.cc b/tensorflow/lite/kernels/arg_min_max_test.cc
index 1b1000f..01ea923 100644
--- a/tensorflow/lite/kernels/arg_min_max_test.cc
+++ b/tensorflow/lite/kernels/arg_min_max_test.cc
@@ -86,6 +86,28 @@
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
}
+TEST(ArgMaxOpTest, GetMaxArgUInt8) {
+ ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_UINT8, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<uint8_t>(model.input(), {1, 9, 7, 3});
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
+}
+
+TEST(ArgMaxOpTest, GetMaxArgInt8) {
+ ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT8, TensorType_INT32,
+ TensorType_INT32);
+ model.PopulateTensor<int8_t>(model.input(), {-1, -9, 7, 3});
+ model.PopulateTensor<int>(model.axis(), {3});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({2}));
+ EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1}));
+}
+
TEST(ArgMaxOpTest, GetMaxArgInt) {
ArgMaxOpModel<int32_t> model({1, 1, 1, 4}, TensorType_INT32, TensorType_INT32,
TensorType_INT32);
diff --git a/tensorflow/lite/kernels/depthwise_conv.cc b/tensorflow/lite/kernels/depthwise_conv.cc
index 3f4ae50..a349b27 100644
--- a/tensorflow/lite/kernels/depthwise_conv.cc
+++ b/tensorflow/lite/kernels/depthwise_conv.cc
@@ -26,6 +26,7 @@
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_float.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
+#include "tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/kernels/op_macros.h"
@@ -58,6 +59,10 @@
// uint8_t these would be 0 and 255.
int32_t output_activation_min;
int32_t output_activation_max;
+
+ // Per channel output multiplier and shift.
+ std::vector<int32_t> per_channel_output_multiplier;
+ std::vector<int> per_channel_output_shift;
};
void* Init(TfLiteContext* context, const char* buffer, size_t length) {
@@ -99,14 +104,15 @@
SizeOfDimension(filter, 3));
const TfLiteType data_type = input->type;
- TF_LITE_ENSURE(context,
- data_type == kTfLiteFloat32 || data_type == kTfLiteUInt8);
+ TF_LITE_ENSURE(context, data_type == kTfLiteFloat32 ||
+ data_type == kTfLiteUInt8 ||
+ data_type == kTfLiteInt8);
TF_LITE_ENSURE_EQ(context, output->type, data_type);
TF_LITE_ENSURE_EQ(context, filter->type, data_type);
if (hasBias) {
bias = GetInput(context, node, kBiasTensor);
- if (data_type == kTfLiteUInt8) {
+ if (data_type == kTfLiteUInt8 || data_type == kTfLiteInt8) {
TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteInt32);
TF_LITE_ENSURE_EQ(context, bias->params.zero_point, 0);
} else {
@@ -150,17 +156,25 @@
filter_width, out_width);
// Note that quantized inference requires that all tensors have their
- // parameters set. This is usually done during quantized training.
+ // parameters set. This is usually done during quantized training or
+ // calibration.
if (data_type != kTfLiteFloat32) {
- double real_multiplier = 0.0;
- TF_LITE_ENSURE_STATUS(GetQuantizedConvolutionMultipler(
- context, input, filter, bias, output, &real_multiplier));
- int exponent;
- QuantizeMultiplier(real_multiplier, &data->output_multiplier, &exponent);
- data->output_shift = -exponent;
- CalculateActivationRangeUint8(params->activation, output,
- &data->output_activation_min,
- &data->output_activation_max);
+ TF_LITE_ENSURE_EQ(context, filter->quantization.type,
+ kTfLiteAffineQuantization);
+ const auto* affine_quantization =
+ reinterpret_cast<TfLiteAffineQuantization*>(
+ filter->quantization.params);
+ TF_LITE_ENSURE(context, affine_quantization);
+ TF_LITE_ENSURE(context, affine_quantization->scale);
+ const int number_channel = affine_quantization->scale->size;
+ data->per_channel_output_multiplier.resize(number_channel);
+ data->per_channel_output_shift.resize(number_channel);
+ TF_LITE_ENSURE_STATUS(tflite::PopulateConvolutionQuantizationParams(
+ context, input, filter, bias, output, params->activation,
+ &data->output_multiplier, &data->output_shift,
+ &data->output_activation_min, &data->output_activation_max,
+ data->per_channel_output_multiplier.data(),
+ data->per_channel_output_shift.data()));
}
TfLiteIntArray* outputSize = TfLiteIntArrayCreate(4);
@@ -250,6 +264,33 @@
GetTensorData<uint8_t>(output));
}
+void EvalQuantizedPerChannel(TfLiteContext* context, TfLiteNode* node,
+ TfLiteDepthwiseConvParams* params, OpData* data,
+ const TfLiteTensor* input,
+ const TfLiteTensor* filter,
+ const TfLiteTensor* bias, TfLiteTensor* output) {
+ DepthwiseParams op_params;
+ op_params.padding_type = PaddingType::kSame;
+ op_params.padding_values.width = data->padding.width;
+ op_params.padding_values.height = data->padding.height;
+ op_params.stride_width = params->stride_width;
+ op_params.stride_height = params->stride_height;
+ op_params.dilation_width_factor = params->dilation_width_factor;
+ op_params.dilation_height_factor = params->dilation_height_factor;
+ op_params.depth_multiplier = params->depth_multiplier;
+ op_params.input_offset = input->params.zero_point;
+ op_params.weights_offset = 0;
+ op_params.output_offset = output->params.zero_point;
+
+ reference_integer_ops::DepthwiseConvPerChannel(
+ op_params, data->per_channel_output_multiplier.data(),
+ data->per_channel_output_shift.data(), GetTensorShape(input),
+ GetTensorData<int8>(input), GetTensorShape(filter),
+ GetTensorData<int8>(filter), GetTensorShape(bias),
+ GetTensorData<int32>(bias), GetTensorShape(output),
+ GetTensorData<int8>(output));
+}
+
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
auto* params =
@@ -273,6 +314,11 @@
EvalQuantized<kernel_type>(context, node, params, data, input, filter,
bias, output);
break;
+ case kTfLiteInt8: {
+ EvalQuantizedPerChannel(context, node, params, data, input, filter, bias,
+ output);
+ break;
+ }
default:
context->ReportError(context, "Type %d not currently supported.",
input->type);
diff --git a/tensorflow/lite/kernels/depthwise_conv_test.cc b/tensorflow/lite/kernels/depthwise_conv_test.cc
index 75aed4c..5dc5132 100644
--- a/tensorflow/lite/kernels/depthwise_conv_test.cc
+++ b/tensorflow/lite/kernels/depthwise_conv_test.cc
@@ -56,9 +56,35 @@
// This is a quantized version. The scale of 'bias' depends on the scales
// of input and filter. Supposedly this is correctly set during quantized
// training.
- auto bias_scale = GetScale(input_) * GetScale(filter_);
- TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
- bias_ = AddInput(bias);
+ if (filter.per_channel_quantization) {
+ // per channel quantization.
+ std::vector<float> bias_scale(
+ filter.per_channel_quantization_scales.size());
+ std::vector<int64_t> bias_zero_points(
+ filter.per_channel_quantization_scales.size());
+ for (int i = 0; i < filter.per_channel_quantization_scales.size();
+ ++i) {
+ bias_scale[i] =
+ input.scale * filter.per_channel_quantization_scales[i];
+ bias_zero_points[i] = 0;
+ }
+ TensorData bias{TensorType_INT32,
+ {bias_size},
+ /*min=*/0,
+ /*max=*/0,
+ /*scale=*/0,
+ /*zero_point=*/0,
+ true,
+ /*per_channel_scale=*/bias_scale,
+ /*per_channel_zero_point=*/bias_zero_points,
+ /*channel_index==*/0};
+ bias_ = AddInput(bias);
+ } else {
+ // per tensor quantization.
+ auto bias_scale = GetScale(input_) * GetScale(filter_);
+ TensorData bias{TensorType_INT32, {bias_size}, 0, 0, bias_scale};
+ bias_ = AddInput(bias);
+ }
}
output_ = AddOutput(output);
@@ -437,6 +463,76 @@
ElementsAreArray({4, 7, 3, 6, 10, 4, 2, 3, 1}));
}
+class PerChannelQuantizedDepthwiseConvolutionOpModel
+ : public BaseDepthwiseConvolutionOpModel {
+ public:
+ using BaseDepthwiseConvolutionOpModel::BaseDepthwiseConvolutionOpModel;
+
+ void SetInput(std::initializer_list<float> data) {
+ QuantizeAndPopulate<int8_t>(input_, data);
+ }
+
+ void SetFilter(std::initializer_list<float> data) {
+ PerChannelSymmetricQuantizeAndPopulate(filter_, data);
+ }
+
+ void SetBias(std::initializer_list<float> data) {
+ PerChannelQuantizeBias(bias_, data);
+ }
+
+ std::vector<int8_t> GetOutput() { return ExtractVector<int8_t>(output_); }
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<int8_t>(ExtractVector<int8_t>(output_), GetScale(output_),
+ GetZeroPoint(output_));
+ }
+};
+
+TEST_P(QuantizedDepthwiseConvolutionOpTest, SimpleTest) {
+ PerChannelQuantizedDepthwiseConvolutionOpModel m(
+ GetRegistration(), {TensorType_INT8, {1, 2, 3, 2}, -63.5, 64, 0.5, -1},
+ {TensorType_INT8,
+ // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+ {1, 2, 2, 4},
+ 0,
+ 0,
+ 0,
+ 0,
+ /*per_channel=*/true,
+ /*per_channel_scales=*/{1, 2, 3, 4},
+ /*per_channel_zeros=*/{0, 0, 0, 0},
+ /*channel_index=*/3},
+ {TensorType_INT8, {}, -63.5, 64, 0.5, -1}, Padding_VALID);
+ m.SetInput({
+ // [1 * 2 * 3 * 2] as [batch, y, x, input_channel]
+ 3, 2, // batch = 0, y = 0, x = 0
+ 1, -1, // batch = 0, y = 0, x = 1
+ -2, -3, // batch = 0, y = 0, x = 2
+ 4, 3, // batch = 0, y = 1, x = 0
+ 2, -2, // batch = 0, y = 1, x = 1
+ -3, -4, // batch = 0, y = 1, x = 2
+ });
+ m.SetFilter(
+ /*filter data*/
+ {
+ // [1 * 2 * 2 * 4] as [input_channel, y, x, output_channel]
+ // depth multiplier = 2
+ 1, 2, 3, 4, // y = 0, x = 0
+ 3, 4, 5, 6, // y = 0, x = 1
+ 7, 8, 5, 6, // y = 1, x = 0
+ 3, 4, 1, 2, // y = 1, x = 1
+ });
+ m.SetBias({3, -2, 4, 6});
+
+ // Invoke and verify output.
+ // output has dimension [1 * 1 * 2 * 4] as [batch, y, x, output_channel]
+ m.Invoke();
+ EXPECT_THAT(
+ m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear({40.5, 48, 27, 40, 0.5, -4, -24, -36})));
+ EXPECT_THAT(m.GetOutput(),
+ ElementsAreArray({80, 95, 53, 79, 0, -9, -49, -73}));
+}
+
INSTANTIATE_TEST_SUITE_P(
DepthwiseConvolutionOpTest, DepthwiseConvolutionOpTest,
::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap)));
diff --git a/tensorflow/lite/kernels/fully_connected.cc b/tensorflow/lite/kernels/fully_connected.cc
index dfc9550..7b4d29c 100644
--- a/tensorflow/lite/kernels/fully_connected.cc
+++ b/tensorflow/lite/kernels/fully_connected.cc
@@ -212,7 +212,9 @@
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE(context,
filter->type == kTfLiteUInt8 || filter->type == kTfLiteInt8);
- TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+ if (bias) {
+ TF_LITE_ENSURE_EQ(context, bias->type, kTfLiteFloat32);
+ }
TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32);
int total_input_size = 1;
diff --git a/tensorflow/lite/kernels/fully_connected_test.cc b/tensorflow/lite/kernels/fully_connected_test.cc
index 03f4ea7..31aa3f3 100644
--- a/tensorflow/lite/kernels/fully_connected_test.cc
+++ b/tensorflow/lite/kernels/fully_connected_test.cc
@@ -137,6 +137,7 @@
BaseFullyConnectedOpModel(
TfLiteRegistration* registration, int units, int batches,
const TensorData& input, const TensorData& output = {TensorType_FLOAT32},
+ bool bias_tensor_optional = false,
ActivationFunctionType activation_func = ActivationFunctionType_RELU,
FullyConnectedOptionsWeightsFormat weights_format =
FullyConnectedOptionsWeightsFormat_DEFAULT)
@@ -151,7 +152,9 @@
weights_ =
AddInput({input.type, {units_, input_size_}, input.min, input.max});
- if (input.type == TensorType_FLOAT32) {
+ if (bias_tensor_optional) {
+ bias_ = AddNullInput();
+ } else if (input.type == TensorType_FLOAT32) {
bias_ = AddInput({TensorType_FLOAT32, {units_}});
} else {
// This is a quantized version. The scale of 'bias' depends on the scales
@@ -173,7 +176,9 @@
.Union());
resolver_ = absl::make_unique<SingleOpResolver>(
BuiltinOperator_FULLY_CONNECTED, registration);
- BuildInterpreter({GetShape(input_), GetShape(weights_), GetShape(bias_)});
+ BuildInterpreter(
+ {GetShape(input_), GetShape(weights_),
+ (bias_ == kOptionalTensor) ? std::vector<int>() : GetShape(bias_)});
}
int input_size() { return input_size_; }
@@ -397,6 +402,27 @@
EXPECT_THAT(m.GetOutput(), ElementsAre(11, 9));
}
+TEST(FloatFullyConnectedOpTest, SimpleTestNoBias) {
+ // The optimized kernel assumes that the bias is specified.
+ FloatFullyConnectedOpModel m(ops::builtin::Register_FULLY_CONNECTED_PIE(),
+ /*units=*/1, /*batches=*/2,
+ /*input=*/{TensorType_FLOAT32, {2, 2}},
+ /*output=*/{TensorType_FLOAT32},
+ /*bias_tensor_optional=*/true);
+ m.SetWeights({
+ 2, 4, // u = 0
+ });
+
+ m.SetInput({
+ 1, 2, // b = 0
+ 2, 1, // b = 1
+ });
+
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput(), ElementsAre(10, 8));
+}
+
TEST_P(QuantizedFullyConnectedOpTest, SimpleTestQuantized) {
QuantizedFullyConnectedOpModel m(
GetRegistration(), /*units=*/3, /*batches*/ 2,
@@ -477,6 +503,7 @@
/*input=*/
{TensorType_UINT8, {batches, input_depth}, kInputMin, kInputMax},
/*output=*/{TensorType_INT16, {}, kOutputMin, kOutputMax},
+ /*bias_tensor_optional=*/false,
/*activation_func=*/ActivationFunctionType_NONE, weights_format);
std::mt19937 random_engine;
diff --git a/tensorflow/lite/kernels/gather.cc b/tensorflow/lite/kernels/gather.cc
index f205daa..54d05ad 100644
--- a/tensorflow/lite/kernels/gather.cc
+++ b/tensorflow/lite/kernels/gather.cc
@@ -57,6 +57,7 @@
switch (input->type) {
case kTfLiteFloat32:
case kTfLiteUInt8:
+ case kTfLiteInt8:
case kTfLiteInt64:
case kTfLiteInt32:
break;
@@ -135,6 +136,8 @@
return Gather<float, int32_t>(*params, input, positions, output);
case kTfLiteUInt8:
return Gather<uint8_t, int32_t>(*params, input, positions, output);
+ case kTfLiteInt8:
+ return Gather<int8_t, int32_t>(*params, input, positions, output);
case kTfLiteInt32:
return Gather<int32_t, int32_t>(*params, input, positions, output);
case kTfLiteInt64:
@@ -153,6 +156,8 @@
return Gather<float, int64_t>(*params, input, positions, output);
case kTfLiteUInt8:
return Gather<uint8_t, int64_t>(*params, input, positions, output);
+ case kTfLiteInt8:
+ return Gather<int8_t, int64_t>(*params, input, positions, output);
case kTfLiteInt32:
return Gather<int32_t, int64_t>(*params, input, positions, output);
case kTfLiteInt64:
diff --git a/tensorflow/lite/kernels/gather_test.cc b/tensorflow/lite/kernels/gather_test.cc
index 7b5f843..b5461c2 100644
--- a/tensorflow/lite/kernels/gather_test.cc
+++ b/tensorflow/lite/kernels/gather_test.cc
@@ -205,6 +205,24 @@
EXPECT_THAT(m.GetOutput<uint8_t>(), ElementsAreArray({14, 15, 133, 134}));
}
+TEST(TypesGatherOpTest, Int8Int32) {
+ GatherOpModel m({TensorType_INT8, {2, 2}}, {TensorType_INT32, {2}});
+ m.SetInput<int8_t>({-13, -120, 14, 15});
+ m.SetPositions<int32_t>({1, 0});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
+}
+
+TEST(TypesGatherOpTest, Int8Int64) {
+ GatherOpModel m({TensorType_INT8, {2, 2}}, {TensorType_INT64, {2}});
+ m.SetInput<int8_t>({-13, -120, 14, 15});
+ m.SetPositions<int64_t>({1LL, 0LL});
+ m.Invoke();
+
+ EXPECT_THAT(m.GetOutput<int8_t>(), ElementsAreArray({14, 15, -13, -120}));
+}
+
TEST(TypesGatherOpTest, Int64Int32) {
GatherOpModel m({TensorType_INT64, {2, 2}}, {TensorType_INT32, {2}});
m.SetInput<int64_t>({-(1LL << 34), 134LL, 14LL, 15LL});
diff --git a/tensorflow/lite/kernels/if.cc b/tensorflow/lite/kernels/if.cc
new file mode 100644
index 0000000..a814ac9
--- /dev/null
+++ b/tensorflow/lite/kernels/if.cc
@@ -0,0 +1,196 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/lite/c/builtin_op_data.h"
+#include "tensorflow/lite/c/c_api_internal.h"
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace custom {
+namespace if_kernel {
+
+struct OpData {
+ int then_subgraph_index;
+ int else_subgraph_index;
+};
+
+void* Init(TfLiteContext* context, const char* buffer, size_t length) {
+ auto* op_data = new OpData;
+ const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
+ const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
+ op_data->then_subgraph_index = m["then_subgraph_index"].AsInt32();
+ op_data->else_subgraph_index = m["else_subgraph_index"].AsInt32();
+ return op_data;
+}
+
+void Free(TfLiteContext* context, void* buffer) {
+ delete reinterpret_cast<OpData*>(buffer);
+}
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ TF_LITE_ENSURE(context, node->inputs->size > 0);
+
+ // The first input is the condition.
+ const TfLiteTensor* cond = GetInput(context, node, 0);
+ // Currently only bool is supported.
+ // TODO(ycling): Support other types since TensorFlow also support
+ // non-bool types as condition.
+ TF_LITE_ENSURE_EQ(context, cond->type, kTfLiteBool);
+ TF_LITE_ENSURE_EQ(context, NumElements(cond), 1);
+
+ // The first input of the node is the condition. The rest of inputs are
+ // passed to the branch subgraphs. Therefore, the number of subgraph inputs
+ // will be the number of node inputs - 1.
+ int num_inputs = node->inputs->size - 1;
+ int num_outputs = node->outputs->size;
+
+ Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
+ auto* subgraphs = this_subgraph->GetSubgraphs();
+ TF_LITE_ENSURE(context, op_data->then_subgraph_index < subgraphs->size());
+ TF_LITE_ENSURE(context, op_data->else_subgraph_index < subgraphs->size());
+
+ Subgraph* then_subgraph = (*subgraphs)[op_data->then_subgraph_index].get();
+ Subgraph* else_subgraph = (*subgraphs)[op_data->else_subgraph_index].get();
+
+ for (auto* subgraph : {then_subgraph, else_subgraph}) {
+ TF_LITE_ENSURE_EQ(context, num_inputs, subgraph->inputs().size());
+ TF_LITE_ENSURE_EQ(context, num_outputs, subgraph->outputs().size());
+ }
+
+ bool has_dynamic_output_tensors = false;
+ for (auto* subgraph : {then_subgraph, else_subgraph}) {
+ for (int i = 0; i < num_inputs; ++i) {
+ // The first input of the node is the condition. The indices of the inputs
+ // passed to the subgraphs are offset by 1.
+ const TfLiteTensor* input = GetInput(context, node, i + 1);
+ std::vector<int> dims(input->dims->data,
+ input->dims->data + input->dims->size);
+ subgraph->ResizeInputTensor(i, dims);
+ TfLiteTensor* subgraph_input = subgraph->tensor(subgraph->inputs()[i]);
+ TF_LITE_ENSURE_EQ(context, input->type, subgraph_input->type);
+ }
+ // Note: The `Prepare` function is responsible to run `AllocateTensors` on
+ // both subgraphs. It's intentionally not to break out of the loop when
+ // finding a dynamic output tensor.
+ TF_LITE_ENSURE_OK(context, subgraph->AllocateTensors());
+ has_dynamic_output_tensors |= subgraph->HasDynamicTensors();
+ }
+
+ if (!has_dynamic_output_tensors) {
+ for (int i = 0; i < num_outputs; ++i) {
+ TfLiteTensor* then_output =
+ then_subgraph->tensor(then_subgraph->outputs()[i]);
+ TfLiteTensor* else_output =
+ else_subgraph->tensor(else_subgraph->outputs()[i]);
+ // If the 2 subgraphs have static but different output shapes, the output
+ // tensors of the IF op have dynamic sizes.
+ if (!TfLiteIntArrayEqual(then_output->dims, else_output->dims)) {
+ has_dynamic_output_tensors = true;
+ break;
+ }
+ }
+ }
+
+ for (int i = 0; i < num_outputs; ++i) {
+ TfLiteTensor* output = GetOutput(context, node, i);
+ if (has_dynamic_output_tensors) {
+ SetTensorToDynamic(output);
+ } else {
+ // When there's no dynamic output tensors, the 2 subgraph has exactly
+ // the same static sized outputs.
+ TfLiteTensor* then_output =
+ then_subgraph->tensor(then_subgraph->outputs()[i]);
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(then_output->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+ }
+ }
+
+ return kTfLiteOk;
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
+
+ const TfLiteTensor* cond = GetInput(context, node, 0);
+ bool cond_value = cond->data.b[0];
+
+ Subgraph* this_subgraph = reinterpret_cast<Subgraph*>(context->impl_);
+ auto* subgraphs = this_subgraph->GetSubgraphs();
+
+ // Currently we copy the input / output between the subgraphs. This isn't
+ // optimized yet.
+ // TODO(b/120234921): Optimize and avoid copying tensors between subgraphs.
+ int active_branch_subgraph_index =
+ cond_value ? op_data->then_subgraph_index : op_data->else_subgraph_index;
+ Subgraph& active_branch_subgraph =
+ *(*subgraphs)[active_branch_subgraph_index];
+ for (int i = 0; i < active_branch_subgraph.inputs().size(); ++i) {
+ const TfLiteTensor* input = GetInput(context, node, i + 1);
+ TfLiteTensor* subgraph_input =
+ active_branch_subgraph.tensor(active_branch_subgraph.inputs()[i]);
+ TF_LITE_ENSURE_EQ(context, input->bytes, subgraph_input->bytes);
+ memcpy(subgraph_input->data.raw, input->data.raw, input->bytes);
+ }
+
+ // Note: It's guaranteed that the subgraphs' `AllocateTensors` are called
+ // in `Prepare`, so we don't need to do it here again.
+ TF_LITE_ENSURE_OK(context, active_branch_subgraph.Invoke());
+
+ bool has_dynamic_output_tensors = false;
+ for (int i = 0; i < node->outputs->size; ++i) {
+ TfLiteTensor* output = GetOutput(context, node, i);
+ if (IsDynamicTensor(output)) {
+ has_dynamic_output_tensors = true;
+ break;
+ }
+ }
+
+ if (has_dynamic_output_tensors) {
+ for (int i = 0; i < node->outputs->size; ++i) {
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TfLiteTensor* subgraph_output =
+ active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
+ TfLiteIntArray* output_size = TfLiteIntArrayCopy(subgraph_output->dims);
+ TF_LITE_ENSURE_OK(context,
+ context->ResizeTensor(context, output, output_size));
+ }
+ }
+
+ for (int i = 0; i < active_branch_subgraph.outputs().size(); ++i) {
+ const TfLiteTensor* subgraph_output =
+ active_branch_subgraph.tensor(active_branch_subgraph.outputs()[i]);
+ TfLiteTensor* output = GetOutput(context, node, i);
+ TF_LITE_ENSURE_EQ(context, output->bytes, subgraph_output->bytes);
+ memcpy(output->data.raw, subgraph_output->data.raw, output->bytes);
+ }
+ return kTfLiteOk;
+}
+
+} // namespace if_kernel
+
+TfLiteRegistration* Register_IF() {
+ static TfLiteRegistration r = {if_kernel::Init, if_kernel::Free,
+ if_kernel::Prepare, if_kernel::Eval};
+ return &r;
+}
+
+} // namespace custom
+} // namespace ops
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/if_test.cc b/tensorflow/lite/kernels/if_test.cc
new file mode 100644
index 0000000..a1460cd
--- /dev/null
+++ b/tensorflow/lite/kernels/if_test.cc
@@ -0,0 +1,120 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <gtest/gtest.h>
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+
+using subgraph_test_util::BuildAddSubgraph;
+using subgraph_test_util::BuildIfSubgraph;
+using subgraph_test_util::BuildMulSubgraph;
+using subgraph_test_util::BuildPadSubgraph;
+using subgraph_test_util::CheckIntTensor;
+using subgraph_test_util::FillIntTensor;
+
+namespace {
+
+// A simple test that performs `ADD` if condition is true, and `MUL` otherwise.
+// The computation is: `cond ? a + b : a * b`.
+class SimpleIfTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+ interpreter_->AddSubgraphs(2);
+ BuildAddSubgraph(interpreter_->subgraph(1));
+ BuildMulSubgraph(interpreter_->subgraph(2));
+ BuildIfSubgraph(&interpreter_->primary_subgraph());
+
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+
+ FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
+ FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
+ }
+ std::unique_ptr<Interpreter> interpreter_;
+};
+
+TEST_F(SimpleIfTest, TestIfTrue) {
+ interpreter_->typed_input_tensor<bool>(0)[0] = true;
+ ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+ TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
+ CheckIntTensor(output, {1, 2}, {6, 9});
+}
+
+TEST_F(SimpleIfTest, TestIfFalse) {
+ interpreter_->typed_input_tensor<bool>(0)[0] = false;
+ ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+ TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
+ CheckIntTensor(output, {1, 2}, {5, 14});
+}
+
+// Test IF op using subgraphs with dynamically sized outputs.
+// The computation is: `cond ? a + b : pad(a, b)`.
+class DynamicSubgraphIfTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ interpreter_.reset(new Interpreter);
+ interpreter_->AddSubgraphs(2);
+ BuildAddSubgraph(interpreter_->subgraph(1));
+ BuildPadSubgraph(interpreter_->subgraph(2));
+ BuildIfSubgraph(&interpreter_->primary_subgraph());
+
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[0], {1});
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[1], {2});
+ interpreter_->ResizeInputTensor(interpreter_->inputs()[2], {1, 2});
+ ASSERT_EQ(interpreter_->AllocateTensors(), kTfLiteOk);
+
+ FillIntTensor(interpreter_->tensor(interpreter_->inputs()[1]), {5, 7});
+ FillIntTensor(interpreter_->tensor(interpreter_->inputs()[2]), {1, 2});
+ }
+ std::unique_ptr<Interpreter> interpreter_;
+};
+
+TEST_F(DynamicSubgraphIfTest, TestIfTrue) {
+ interpreter_->typed_input_tensor<bool>(0)[0] = true;
+ ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+ TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
+ // Even if the true branch has a static type output, the output of the
+ // if op is dynamic because the other branch has dynamic output.
+ EXPECT_TRUE(IsDynamicTensor(output));
+ CheckIntTensor(output, {1, 2}, {6, 9});
+}
+
+TEST_F(DynamicSubgraphIfTest, TestIfFalse) {
+ interpreter_->typed_input_tensor<bool>(0)[0] = false;
+ ASSERT_EQ(interpreter_->Invoke(), kTfLiteOk);
+ TfLiteTensor* output = interpreter_->tensor(interpreter_->outputs()[0]);
+ // The false branch has dynamic output.
+ EXPECT_TRUE(IsDynamicTensor(output));
+ CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
+}
+
+} // namespace
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/kernels/internal/BUILD b/tensorflow/lite/kernels/internal/BUILD
index 9d982a8..97b3225 100644
--- a/tensorflow/lite/kernels/internal/BUILD
+++ b/tensorflow/lite/kernels/internal/BUILD
@@ -307,6 +307,7 @@
"reference/depthwiseconv_uint8.h",
"reference/fully_connected.h",
"reference/integer_ops/conv.h",
+ "reference/integer_ops/depthwise_conv.h",
"reference/integer_ops/dequantize.h",
"reference/integer_ops/pooling.h",
"reference/integer_ops/softmax.h",
@@ -590,6 +591,7 @@
":reference_base",
":test_util",
":types",
+ "@com_google_absl//absl/strings",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc
index 3682499..b396e62 100644
--- a/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc
+++ b/tensorflow/lite/kernels/internal/depthwiseconv_quantized_test.cc
@@ -19,6 +19,7 @@
#include <cstdlib>
#include <iterator>
#include <limits>
+#include <string>
#include <vector>
#include <gtest/gtest.h>
@@ -26,6 +27,7 @@
#include "tensorflow/lite/kernels/internal/types.h"
#define ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK
+#include "absl/strings/substitute.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
@@ -33,26 +35,71 @@
namespace tflite {
namespace {
-enum class ForceKernelInvocation {
- // Run all tests against kUseStandardEntry even if also testing another
- // kernel, since we need to be sure that the main DepthwiseConv() function in
- // optimized_ops.h dispatches to a correctly-executing kernel.
- kNone = 0, // The "default" option: use the normal DepthwiseConv
- // kernel (entry) function.
- kUseGenericKernel,
- kUseNeon3x3, // 3x3 kernel that uses NEON when available.
- kUseNeon3x3DotProduct, // 3x3 kernel that uses dot-product enabled NEON when
- // available.
+using ::testing::Bool;
+using ::testing::Values;
+
+// Currently, this is used in place of a Boolean "is symmetric?".
+enum class ParamsSpecialization {
+ kNone = 0,
+ kSymmetric, // Symmetric quantization: zero represented by 128.
+};
+
+static constexpr int kSymmetricZeroPoint = 128;
+
+// Extend coverage distribution in a specific aspect, either explicitly chosen
+// or randomly chosen as in a mixture distribution.
+enum class CoverageExtension {
+ kNone = 0,
+ kLargeHeights = 1,
+ kLargeWidths = 2,
+ kNumOptions
+};
+
+// The TestParam structure below is the preferred parameterization of tests. A
+// tuple version is defined in order to support value-parameterized tests.
+typedef std::tuple<DepthwiseConvInvocation, int, bool, bool, bool,
+ DepthwiseConvOutputRounding, bool>
+ TestParamTuple;
+
+struct TestParam {
+ TestParam() = default;
+
+ explicit TestParam(TestParamTuple param_tuple)
+ : forced_invocation(::testing::get<0>(param_tuple)),
+ tests_to_run(::testing::get<1>(param_tuple)),
+ test_stride(::testing::get<2>(param_tuple)),
+ test_pad(::testing::get<3>(param_tuple)),
+ test_depth_multiplier(::testing::get<4>(param_tuple)),
+ output_rounding(::testing::get<5>(param_tuple)),
+ loose_tolerance(::testing::get<6>(param_tuple)) {}
+
+ static std::string TestNameSuffix(
+ const ::testing::TestParamInfo<TestParamTuple>& info) {
+ const TestParam param(info.param);
+ return absl::Substitute("invocation_$0_stride_$1_pad_$2_depth_mult_$3",
+ static_cast<int>(param.forced_invocation),
+ param.test_stride, param.test_pad,
+ param.test_depth_multiplier);
+ }
+
+ DepthwiseConvInvocation forced_invocation = DepthwiseConvInvocation::kNone;
+ int tests_to_run = 0;
+ bool test_stride = false;
+ bool test_pad = false;
+ bool test_depth_multiplier = false;
+ DepthwiseConvOutputRounding output_rounding =
+ DepthwiseConvOutputRounding::kNone;
+ bool loose_tolerance = false;
};
inline void DispatchDepthwiseConv(
- ForceKernelInvocation forced_invocation, const DepthwiseParams& params,
+ const TestParam& test_param, const DepthwiseParams& params,
const RuntimeShape& input_shape, const uint8* input_data,
const RuntimeShape& filter_shape, const uint8* filter_data,
const RuntimeShape& bias_shape, const int32* bias_data,
const RuntimeShape& output_shape, uint8* output_data) {
- switch (forced_invocation) {
- case ForceKernelInvocation::kUseNeon3x3: {
+ switch (test_param.forced_invocation) {
+ case DepthwiseConvInvocation::kUseNeon3x3: {
// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
// Jetson TX-2. This compiler does not support the offsetof() macro.
#if defined(__aarch64__) && !defined(GOOGLE_L4T)
@@ -74,10 +121,10 @@
ASSERT_TRUE(basic_3x3_kernel_supported)
<< "pad_width = " << params.padding_values.width
<< " pad_height = " << params.padding_values.height
- << " input_width = " << input_shape.Dims(1)
- << " input_height = " << input_shape.Dims(2)
- << " output_width = " << output_shape.Dims(1)
- << " output_height = " << output_shape.Dims(2);
+ << " input_width = " << input_shape.Dims(2)
+ << " input_height = " << input_shape.Dims(1)
+ << " output_width = " << output_shape.Dims(2)
+ << " output_height = " << output_shape.Dims(1);
// Call kernel optimized for depthwise convolutions using 3x3 filters.
optimized_ops::DepthwiseConv3x3Filter(
@@ -88,56 +135,24 @@
break;
#endif
}
- case ForceKernelInvocation::kUseNeon3x3DotProduct: {
-// Enable for arm64 except for the Nvidia Linux 4 Tegra (L4T) running on
-// Jetson TX-2. This compiler does not support the offsetof() macro.
-#if defined(__ARM_FEATURE_DOTPROD) && defined(__aarch64__) && \
- !defined(GOOGLE_L4T)
- using optimized_ops::DotProduct3x3KernelType;
- DotProduct3x3KernelType kernel_type =
- optimized_ops::CategorizeDotProductKernel(params);
- switch (kernel_type) {
- case DotProduct3x3KernelType::kPlain:
- // TODO(b/118430534): Implement optimized kernel.
- optimized_ops::DepthwiseConv3x3Filter(
- params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data);
- return;
- case DotProduct3x3KernelType::kWithDepthMultiplication:
- // TODO(b/118430338): Implement optimized kernel.
- optimized_ops::DepthwiseConvGeneral(
- params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data);
- return;
- case DotProduct3x3KernelType::kWithPad0Stride2:
- // TODO(b/118430338): Implement optimized kernel.
- optimized_ops::DepthwiseConv3x3Filter(
- params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data);
- return;
- case DotProduct3x3KernelType::kWithPad1Stride1:
- // TODO(b/118430338): Implement optimized kernel.
- optimized_ops::DepthwiseConvGeneral(
- params, input_shape, input_data, filter_shape, filter_data,
- bias_shape, bias_data, output_shape, output_data);
- return;
- case DotProduct3x3KernelType::kNone:
- default:
- break;
- }
-#endif
+ case DepthwiseConvInvocation::kUseNeon3x3DotProduct:
+ case DepthwiseConvInvocation::kUseCModel3x3DotProduct:
+ case DepthwiseConvInvocation::kUseUnwound3x3DotProduct:
+ case DepthwiseConvInvocation::kUseIntrinsics3x3DotProduct:
+ // TODO(b/118426582) Placeholder for future dispatches.
break;
- }
- case ForceKernelInvocation::kUseGenericKernel: {
+ case DepthwiseConvInvocation::kUseGenericKernel: {
optimized_ops::DepthwiseConvGeneral(params, input_shape, input_data,
filter_shape, filter_data, bias_shape,
bias_data, output_shape, output_data);
return;
}
- case ForceKernelInvocation::kNone:
+ case DepthwiseConvInvocation::kNone:
default:
break;
}
+ EXPECT_EQ(test_param.forced_invocation, DepthwiseConvInvocation::kNone)
+ << "TODO(b/118426582) requested kernel was not invoked / available yet";
optimized_ops::DepthwiseConv(params, input_shape, input_data, filter_shape,
filter_data, bias_shape, bias_data, output_shape,
output_data);
@@ -145,7 +160,7 @@
// Runs the DepthwiseConv and compares against the reference implementation.
int TestOneDepthwiseConvWithGivenOutputShift(
- ForceKernelInvocation forced_invocation, const std::uint8_t* input_data,
+ const TestParam& test_param, const std::uint8_t* input_data,
const RuntimeShape& input_shape, std::int32_t input_offset,
const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
std::int32_t filter_offset, const std::int32_t* bias_data,
@@ -174,10 +189,31 @@
op_params.output_offset = output_offset;
op_params.output_multiplier = output_multiplier;
op_params.output_shift = -output_shift;
- reference_ops::DepthwiseConv(op_params, input_shape, input_data, filter_shape,
- filter_data, bias_shape, bias_data, output_shape,
- reference_output_data.data());
- DispatchDepthwiseConv(forced_invocation, op_params, input_shape, input_data,
+ switch (test_param.output_rounding) {
+ case DepthwiseConvOutputRounding::kUpward:
+ reference_ops::DepthwiseConvBasicKernel<
+ DepthwiseConvOutputRounding::kAwayFromZero>::Run(op_params,
+ input_shape,
+ input_data,
+ filter_shape,
+ filter_data,
+ bias_shape,
+ bias_data,
+ output_shape,
+ reference_output_data
+ .data());
+ break;
+ case DepthwiseConvOutputRounding::kAwayFromZero:
+ reference_ops::DepthwiseConv(
+ op_params, input_shape, input_data, filter_shape, filter_data,
+ bias_shape, bias_data, output_shape, reference_output_data.data());
+ break;
+ case DepthwiseConvOutputRounding::kNone:
+ default:
+ EXPECT_NE(test_param.output_rounding, DepthwiseConvOutputRounding::kNone);
+ break;
+ }
+ DispatchDepthwiseConv(test_param, op_params, input_shape, input_data,
filter_shape, filter_data, bias_shape, bias_data,
output_shape, output_data.data());
int saturated_min = 0;
@@ -201,15 +237,26 @@
const float mean_diff = static_cast<float>(sum_diff) / output_buffer_size;
const float mean_abs_diff =
static_cast<float>(sum_abs_diff) / output_buffer_size;
+
+ constexpr int diff_mean_tolerance = 1;
+ constexpr int diff_median_tolerance = 0;
+ // The tolerance that we apply to means is tight, but we allow for a rounding
+ // difference in one pixel, and loosen by another 1% for float comparison.
+ const float mean_tolerance =
+ std::max(1e-5f, 1.01f * 2.f / output_buffer_size *
+ std::sqrt(1.f * depth_multiplier));
+
// Normally we should require bit-for-bit exact results. Unfortunately a bug
// in the Intel arm_neon_sse.h translation header that we use for x86 tests
- // causes 1-bit inaccuracy in
- // the vqrdmulh_n_s32 intrinsic, which causes off-by-1 errors in quantized
- // DepthwiseConv ops. So we have to live with a few off-by-one errors for now,
- // yet still ensure that no more than a small minority of values are wrong.
- EXPECT_TRUE(std::abs(mean_diff) < 1e-5f && mean_abs_diff < 1e-5f &&
- std::abs(median_diff) == 0 && std::abs(min_diff) <= 1 &&
- std::abs(max_diff) <= 1);
+ // causes 1-bit inaccuracy in the vqrdmulh_n_s32 intrinsic, which causes
+ // off-by-1 errors in quantized DepthwiseConv ops. So we have to live with a
+ // few off-by-one errors for now, yet still ensure that no more than a small
+ // minority of values are wrong.
+ EXPECT_LT(std::abs(mean_diff), mean_tolerance);
+ EXPECT_LT(mean_abs_diff, mean_tolerance);
+ EXPECT_LE(std::abs(median_diff), diff_median_tolerance);
+ EXPECT_LE(std::abs(min_diff), diff_mean_tolerance);
+ EXPECT_LE(std::abs(max_diff), diff_mean_tolerance);
if (saturated_min > 2 * saturated_max) {
return -1;
}
@@ -221,13 +268,12 @@
// The point of this function is that we can't practically know which
// output_shift value to pass to test DepthwiseConv. It's not easy to guess (we
-// could do some
-// statistics for large size, but they would be fragile at smaller sizes), and
-// guessing wrong would mean that all the values get saturated so the test
-// becomes
-// vacuous. So we just bisect our way to reasonable output_shift values.
+// could do some statistics for large size, but they would be fragile at smaller
+// sizes), and guessing wrong would mean that all the values get saturated so
+// the test becomes vacuous. So we just bisect our way to reasonable
+// output_shift values.
void TestOneDepthwiseConvBisectOutputShift(
- ForceKernelInvocation forced_invocation, const std::uint8_t* input_data,
+ const TestParam& test_param, const std::uint8_t* input_data,
const RuntimeShape& input_shape, std::int32_t input_offset,
const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
std::int32_t filter_offset, const std::int32_t* bias_data,
@@ -242,7 +288,7 @@
int output_shift_bisect_midpoint =
(output_activation_bisect_start + output_activation_bisect_end) / 2;
int bisect_result = TestOneDepthwiseConvWithGivenOutputShift(
- forced_invocation, input_data, input_shape, input_offset, filter_data,
+ test_param, input_data, input_shape, input_offset, filter_data,
filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type,
pad_width, pad_height, depth_multiplier, output_offset, output_multiplier,
output_shift_bisect_midpoint, output_activation_min,
@@ -269,7 +315,7 @@
? output_activation_bisect_end
: output_shift_bisect_midpoint;
TestOneDepthwiseConvBisectOutputShift(
- forced_invocation, input_data, input_shape, input_offset, filter_data,
+ test_param, input_data, input_shape, input_offset, filter_data,
filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type,
pad_width, pad_height, depth_multiplier, output_offset, output_multiplier,
new_output_activation_bisect_start, new_output_activation_bisect_end,
@@ -277,7 +323,7 @@
}
void TestOneDepthwiseConv(
- ForceKernelInvocation forced_invocation, const std::uint8_t* input_data,
+ const TestParam& test_param, const std::uint8_t* input_data,
const RuntimeShape& input_shape, std::int32_t input_offset,
const std::uint8_t* filter_data, const RuntimeShape& filter_shape,
std::int32_t filter_offset, const std::int32_t* bias_data,
@@ -287,13 +333,14 @@
std::int32_t output_activation_min, std::int32_t output_activation_max,
const RuntimeShape& output_shape) {
TestOneDepthwiseConvBisectOutputShift(
- forced_invocation, input_data, input_shape, input_offset, filter_data,
+ test_param, input_data, input_shape, input_offset, filter_data,
filter_shape, filter_offset, bias_data, bias_shape, stride, padding_type,
pad_width, pad_height, depth_multiplier, output_offset, output_multiplier,
0, 32, output_activation_min, output_activation_max, output_shape);
}
-bool TryTestDepthwiseConv(ForceKernelInvocation forced_invocation, int batch,
+bool TryTestDepthwiseConv(const TestParam& test_param,
+ ParamsSpecialization params_specialization, int batch,
int input_depth, int input_width, int input_height,
int filter_width, int filter_height,
int depth_multiplier, int stride,
@@ -318,9 +365,12 @@
}
const std::int32_t output_multiplier =
UniformRandomInt(1 << 29, std::numeric_limits<std::int32_t>::max());
- const std::int32_t input_offset = UniformRandomInt(-256, 0);
- const std::int32_t filter_offset = UniformRandomInt(-256, 0);
- const std::int32_t output_offset = UniformRandomInt(-256, 0);
+ std::int32_t filter_offset = -kSymmetricZeroPoint;
+ if (params_specialization != ParamsSpecialization::kSymmetric) {
+ filter_offset = UniformRandomInt(-255, 0);
+ }
+ const std::int32_t input_offset = UniformRandomInt(-255, 0);
+ const std::int32_t output_offset = UniformRandomInt(0, 255);
RuntimeShape input_shape_inference(
{batch, input_height, input_width, input_depth});
RuntimeShape output_shape_inference;
@@ -343,7 +393,7 @@
FillRandom(&filter_data);
FillRandom(&bias_data, -10000, 10000);
TestOneDepthwiseConv(
- forced_invocation, input_data.data(), input_shape_inference, input_offset,
+ test_param, input_data.data(), input_shape_inference, input_offset,
filter_data.data(), filter_shape_inference, filter_offset,
bias_data.data(), bias_shape_inference, stride, padding_type, pad_width,
pad_height, depth_multiplier, output_offset, output_multiplier,
@@ -355,7 +405,8 @@
// be legal. If they're not legal, it returns false. If they're legal,
// it runs the DepthwiseConv test and returns true. This allows the caller
// to loop until a test has been run.
-bool TryTestOneDepthwiseConv(ForceKernelInvocation forced_invocation) {
+bool TryTestOneDepthwiseConv(const TestParam& test_param,
+ ParamsSpecialization params_specialization) {
// We have to pick a lot of positive values, where we are particularly
// interested in small values because they are most likely to be special
// cases in optimized implementations, and secondarily because they allow
@@ -375,13 +426,14 @@
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(
- forced_invocation, batch, input_depth, input_width, input_height,
- filter_width, filter_height, depth_multiplier, stride,
+ test_param, params_specialization, batch, input_depth, input_width,
+ input_height, filter_width, filter_height, depth_multiplier, stride,
dilation_width_factor, dilation_height_factor, padding_type);
}
// Tests parameters for the 3x3 filter kernel.
-bool TryTestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) {
+bool TryTestOneDepthwiseConv3x3Filter(
+ const TestParam& test_param, ParamsSpecialization params_specialization) {
const int batch = ExponentialRandomPositiveInt(0.9f, 3, 20);
const int input_depth = 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);
int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
@@ -397,7 +449,7 @@
UniformRandomInt(0, 1) ? PaddingType::kSame : PaddingType::kValid;
// Adjust for, or reject, special cases.
- if (forced_invocation != ForceKernelInvocation::kNone) {
+ if (test_param.forced_invocation != DepthwiseConvInvocation::kNone) {
// With stride == 2 and SAME, padding width and height are the left and top
// padding amounts. When there is an even input dimension, padding + 1 is
// required on the right / bottom. This is not handled by these kernels, so
@@ -416,59 +468,77 @@
}
return TryTestDepthwiseConv(
- forced_invocation, batch, input_depth, input_width, input_height,
- filter_width, filter_height, depth_multiplier, stride,
+ test_param, params_specialization, batch, input_depth, input_width,
+ input_height, filter_width, filter_height, depth_multiplier, stride,
dilation_width_factor, dilation_height_factor, padding_type);
}
// Tests with parameters suited to dot-product-NEON 3x3 filter kernels.
-bool TryTestOneNeonDot3x3(ForceKernelInvocation forced_invocation,
- bool test_stride, bool test_pad,
- bool test_depth_multiplier) {
+bool TryTestOneNeonDot3x3(const TestParam& test_param,
+ ParamsSpecialization params_specialization) {
+ const CoverageExtension coverage_extension = static_cast<CoverageExtension>(
+ UniformRandomInt(0, static_cast<int>(CoverageExtension::kNumOptions)));
+
const int batch = 1;
- const int input_depth = test_depth_multiplier
+ const int input_depth = test_param.test_depth_multiplier
? 1
- : 8 * ExponentialRandomPositiveInt(0.9f, 10, 50);
- const int input_width = ExponentialRandomPositiveInt(0.9f, 20, 200);
- const int input_height = ExponentialRandomPositiveInt(0.9f, 20, 200);
+ : 8 * ExponentialRandomPositiveInt(0.9f, 3, 50);
+ const int input_width = coverage_extension == CoverageExtension::kLargeWidths
+ ? ExponentialRandomPositiveInt(0.9f, 50, 200)
+ : ExponentialRandomPositiveInt(0.9f, 20, 60);
+ const int input_height =
+ coverage_extension == CoverageExtension::kLargeHeights
+ ? ExponentialRandomPositiveInt(0.9f, 50, 200)
+ : ExponentialRandomPositiveInt(0.9f, 20, 60);
const int filter_width = 3;
const int filter_height = 3;
const int depth_multiplier =
- test_depth_multiplier ? 8 * ExponentialRandomPositiveInt(0.8f, 1, 6) : 1;
- const int stride = test_stride ? 2 : 1;
+ test_param.test_depth_multiplier
+ ? 8 * ExponentialRandomPositiveInt(0.2f, 1, 9)
+ : 1;
+ const int stride = test_param.test_stride ? 2 : 1;
// We don't support dilations in the 3x3 filter.
const int dilation_width_factor = 1;
const int dilation_height_factor = 1;
- const auto padding_type = test_pad ? PaddingType::kSame : PaddingType::kValid;
+ const auto padding_type =
+ test_param.test_pad ? PaddingType::kSame : PaddingType::kValid;
return TryTestDepthwiseConv(
- forced_invocation, batch, input_depth, input_width, input_height,
- filter_width, filter_height, depth_multiplier, stride,
+ test_param, params_specialization, batch, input_depth, input_width,
+ input_height, filter_width, filter_height, depth_multiplier, stride,
dilation_width_factor, dilation_height_factor, padding_type);
}
-void TestOneDepthwiseConv(ForceKernelInvocation forced_invocation) {
- while (!TryTestOneDepthwiseConv(forced_invocation)) {
+void TestOneDepthwiseConv(DepthwiseConvInvocation forced_invocation,
+ DepthwiseConvOutputRounding output_rounding) {
+ TestParam test_param;
+ test_param.forced_invocation = forced_invocation;
+ test_param.output_rounding = output_rounding;
+ while (!TryTestOneDepthwiseConv(test_param, ParamsSpecialization::kNone)) {
}
}
-void TestOneDepthwiseConv3x3Filter(ForceKernelInvocation forced_invocation) {
- while (!TryTestOneDepthwiseConv3x3Filter(forced_invocation)) {
+void TestOneDepthwiseConv3x3Filter(
+ DepthwiseConvInvocation forced_invocation,
+ DepthwiseConvOutputRounding output_rounding) {
+ TestParam test_param;
+ test_param.forced_invocation = forced_invocation;
+ test_param.output_rounding = output_rounding;
+ while (!TryTestOneDepthwiseConv3x3Filter(test_param,
+ ParamsSpecialization::kNone)) {
}
}
-void TestOneNeonDot3x3(ForceKernelInvocation forced_invocation,
- bool test_stride, bool test_pad,
- bool test_depth_multiplier) {
- while (!TryTestOneNeonDot3x3(forced_invocation, test_stride, test_pad,
- test_depth_multiplier)) {
+void TestOneNeonDot3x3(const TestParam& test_param) {
+ while (!TryTestOneNeonDot3x3(test_param, ParamsSpecialization::kSymmetric)) {
}
}
TEST(TestDepthwiseConv, TestDepthwiseConv) {
const int kTestsToRun = 10 * 1000;
for (int i = 0; i < kTestsToRun; i++) {
- TestOneDepthwiseConv(ForceKernelInvocation::kNone);
+ TestOneDepthwiseConv(DepthwiseConvInvocation::kNone,
+ DepthwiseConvOutputRounding::kAwayFromZero);
}
}
@@ -476,69 +546,78 @@
TEST(TestDepthwiseConv, TestGenericKernel) {
const int kTestsToRun = 10 * 1000;
for (int i = 0; i < kTestsToRun; i++) {
- TestOneDepthwiseConv(ForceKernelInvocation::kUseGenericKernel);
+ TestOneDepthwiseConv(DepthwiseConvInvocation::kUseGenericKernel,
+ DepthwiseConvOutputRounding::kAwayFromZero);
}
}
TEST(TestDepthwiseConv, TestKernel3x3Filter) {
const int kTestsToRun = 1000;
for (int i = 0; i < kTestsToRun; i++) {
- TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kNone);
+ TestOneDepthwiseConv3x3Filter(DepthwiseConvInvocation::kNone,
+ DepthwiseConvOutputRounding::kAwayFromZero);
}
}
-// While the 3x3 coverage test is primarily targeted at specialized kernels, we
-// also run it against the generic kernel, optionally with fewer invocations.
+// While 3x3 coverage tests are primarily targeted at specialized kernels, we
+// also run it against the generic kernel.
TEST(TestDepthwiseConv, TestGenericKernel3x3Filter) {
- const int kTestsToRun = 1000;
+ const int kTestsToRun = 100;
for (int i = 0; i < kTestsToRun; i++) {
- TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kUseGenericKernel);
+ TestOneDepthwiseConv3x3Filter(DepthwiseConvInvocation::kUseGenericKernel,
+ DepthwiseConvOutputRounding::kAwayFromZero);
}
}
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
TEST(TestDepthwiseConv, TestNeon3x3Filter) {
const int kTestsToRun = 3 * 1000;
for (int i = 0; i < kTestsToRun; i++) {
- TestOneDepthwiseConv3x3Filter(ForceKernelInvocation::kUseNeon3x3);
+ TestOneDepthwiseConv3x3Filter(DepthwiseConvInvocation::kUseNeon3x3,
+ DepthwiseConvOutputRounding::kAwayFromZero);
+ }
+}
+#endif
+
+class DepthwiseConvTest : public ::testing::TestWithParam<TestParamTuple> {};
+
+TEST_P(DepthwiseConvTest, NeonDot3x3) {
+ const TestParam param(GetParam());
+ for (int i = 0; i < param.tests_to_run; i++) {
+ TestOneNeonDot3x3(param);
}
}
-// No stride, no depth multiplier, no pad.
-TEST(TestDepthwiseConv, TestNeonDot3x3Plain) {
- const int kTestsToRun = 3 * 1000;
- for (int i = 0; i < kTestsToRun; i++) {
- TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct,
- /*test_stride=*/false, /*test_pad=*/false,
- /*test_depth_multiplier=*/false);
- }
-}
+#if defined(__aarch64__) && !defined(GOOGLE_L4T)
+INSTANTIATE_TEST_SUITE_P(
+ Neon3x3Kernel, DepthwiseConvTest,
+ testing::Combine(
+ Values(DepthwiseConvInvocation::kUseNeon3x3), // forced_invocation
+ Values(1000), // tests_to_run
+ Bool(), // test_stride
+ Values(false), // test_pad
+ Values(false), // test_depth_multiplier
+ Values(DepthwiseConvOutputRounding::kAwayFromZero), // output_rounding
+ Values(false) // loose_tolerance
+ ),
+ TestParam::TestNameSuffix);
+#endif
-TEST(TestDepthwiseConv, TestNeonDot3x3DepthMultiplier) {
- const int kTestsToRun = 3 * 1000;
- for (int i = 0; i < kTestsToRun; i++) {
- TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct,
- /*test_stride=*/false, /*test_pad=*/false,
- /*test_depth_multiplier=*/true);
- }
-}
-
-TEST(TestDepthwiseConv, TestNeonDot3x3Stride2) {
- const int kTestsToRun = 3 * 1000;
- for (int i = 0; i < kTestsToRun; i++) {
- TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct,
- /*test_stride=*/true, /*test_pad=*/false,
- /*test_depth_multiplier=*/false);
- }
-}
-
-TEST(TestDepthwiseConv, TestNeonDot3x3Pad1) {
- const int kTestsToRun = 3 * 1000;
- for (int i = 0; i < kTestsToRun; i++) {
- TestOneNeonDot3x3(ForceKernelInvocation::kUseNeon3x3DotProduct,
- /*test_stride=*/false, /*test_pad=*/true,
- /*test_depth_multiplier=*/false);
- }
-}
+// While 3x3 coverage tests are primarily targeted at specialized kernels, we
+// also run it against the generic kernel.
+INSTANTIATE_TEST_SUITE_P(
+ GenericKernel, DepthwiseConvTest,
+ testing::Combine(
+ Values(
+ DepthwiseConvInvocation::kUseGenericKernel), // forced_invocation
+ Values(100), // tests_to_run
+ Bool(), // test_stride
+ Bool(), // test_pad
+ Bool(), // test_depth_multiplier
+ Values(DepthwiseConvOutputRounding::kUpward), // output_rounding
+ Values(false) // loose_tolerance
+ ),
+ TestParam::TestNameSuffix);
} // namespace
} // namespace tflite
diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
index 0f4226d..1362949 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8.h
@@ -19,6 +19,7 @@
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
diff --git a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
index 5859bca..b7993c3 100644
--- a/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
+++ b/tensorflow/lite/kernels/internal/optimized/depthwiseconv_uint8_3x3_filter.h
@@ -18,6 +18,7 @@
#include "fixedpoint/fixedpoint.h"
#include "public/gemmlowp.h"
#include "tensorflow/lite/kernels/internal/common.h"
+#include "tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@@ -27,33 +28,33 @@
enum class DotProduct3x3KernelType {
kNone = 0, // Parameter combination is not supported for dot product kernels.
kPlain,
- kWithDepthMultiplication,
- kWithPad0Stride2,
- kWithPad1Stride1,
+ kWithDepthMultiplicationStride1,
+ kWithDepthMultiplicationStride2,
+ kStride2,
};
inline DotProduct3x3KernelType CategorizeDotProductKernel(
const DepthwiseParams& params) {
- const int padding = params.padding_values.width;
+ const int padding =
+ std::max(params.padding_values.width, params.padding_values.height);
const int stride = params.stride_width;
- if (padding != params.padding_values.height ||
- stride != params.stride_height) {
+ if (stride != params.stride_height || padding > 1) {
return DotProduct3x3KernelType::kNone;
}
if (params.depth_multiplier == 1) {
- if (padding == 0 && stride == 1) {
+ if (stride == 1) {
return DotProduct3x3KernelType::kPlain;
- } else if (padding == 0 && stride == 2) {
- return DotProduct3x3KernelType::kWithPad0Stride2;
- } else if (padding == 1 && stride == 1) {
- return DotProduct3x3KernelType::kWithPad1Stride1;
+ } else if (stride == 2) {
+ return DotProduct3x3KernelType::kStride2;
} else {
return DotProduct3x3KernelType::kNone;
}
} else {
- if (padding == 0 && stride == 1) {
- return DotProduct3x3KernelType::kWithDepthMultiplication;
+ if (stride == 1) {
+ return DotProduct3x3KernelType::kWithDepthMultiplicationStride1;
+ } else if (stride == 2) {
+ return DotProduct3x3KernelType::kWithDepthMultiplicationStride2;
} else {
return DotProduct3x3KernelType::kNone;
}
diff --git a/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h
index 002444b..7cc5679 100644
--- a/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h
+++ b/tensorflow/lite/kernels/internal/reference/depthwiseconv_uint8.h
@@ -23,90 +23,170 @@
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
+
+// Used in tests and template parameters to control which version of depthwise
+// convolution is called. Primarily for reference code, and specializations
+// forced in tests.
+enum class DepthwiseConvInvocation {
+ // Run all tests against kUseStandardEntry even if also testing another
+ // kernel, since we need to be sure that the main DepthwiseConv() function in
+ // optimized_ops.h dispatches to a correctly-executing kernel.
+ kNone = 0, // The "default" option: use the normal
+ // DepthwiseConv kernel (entry) function.
+ kUseGenericKernel, // Forced use of generic kernel.
+ kUseNeon3x3, // 3x3 kernel that uses NEON when available.
+ kUseNeon3x3DotProduct, // 3x3 kernel that uses dot-product enabled NEON
+ // when available.
+ kUseCModel3x3DotProduct, // 3x3 kernel, reference C model that is intended
+ // to match overall design NEON code.
+ kUseUnwound3x3DotProduct, // 3x3 kernel, reference C model with unwound loops
+ // and some arrays.
+ kUseIntrinsics3x3DotProduct, // 3x3 kernel using NEON intrinsics.
+};
+
+// Category of depthwise convolution output rounding.
+enum class DepthwiseConvOutputRounding {
+ kNone = 0, // Invalid: specific method must be specified.
+ kAwayFromZero, // Original method: exact halves rounded away from zero.
+ kUpward, // Halves towards +infinity: adds 0.5 before truncate.
+ // This is where a future kNearestEven would be placed.
+};
+
+// Category of depthwise convolution depth multiplication.
+enum class DepthwiseConvDepthMultiplication {
+ kNoMultiplication = 0, // Depth multiplier = 1.
+ kUnitInputDepth, // Input depth = 1, output depth = depth multiplier.
+};
+
namespace reference_ops {
+template <DepthwiseConvOutputRounding output_rounding>
+inline int32 DepthwiseConvRound(int32 x, int32 quantized_multiplier,
+ int shift) {
+ TFLITE_DCHECK_NE(output_rounding, DepthwiseConvOutputRounding::kNone);
+ return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
+}
+
+template <>
+inline int32 DepthwiseConvRound<DepthwiseConvOutputRounding::kAwayFromZero>(
+ int32 x, int32 quantized_multiplier, int shift) {
+ return MultiplyByQuantizedMultiplier(x, quantized_multiplier, shift);
+}
+
+template <>
+inline int32 DepthwiseConvRound<DepthwiseConvOutputRounding::kUpward>(
+ int32 x, int32 quantized_multiplier, int shift) {
+ using gemmlowp::SaturatingRoundingDoublingHighMul;
+ const int left_shift = shift > 0 ? shift : 0;
+ const int right_shift = shift > 0 ? 0 : -shift;
+ const int rounding_offset = right_shift > 0 ? 1 << (right_shift - 1) : 0;
+ return (SaturatingRoundingDoublingHighMul(x * (1 << left_shift),
+ quantized_multiplier) +
+ rounding_offset) >>
+ right_shift;
+}
+
+template <DepthwiseConvOutputRounding output_rounding>
+struct DepthwiseConvBasicKernel {
+ static inline void Run(const DepthwiseParams& params,
+ const RuntimeShape& input_shape,
+ const uint8* input_data,
+ const RuntimeShape& filter_shape,
+ const uint8* filter_data,
+ const RuntimeShape& bias_shape, const int32* bias_data,
+ const RuntimeShape& output_shape, uint8* output_data) {
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 output_activation_min = params.quantized_activation_min;
+ const int32 output_activation_max = params.quantized_activation_max;
+ const int32 input_offset = params.input_offset;
+ const int32 filter_offset = params.weights_offset;
+ const int32 output_offset = params.output_offset;
+ const int32 output_multiplier = params.output_multiplier;
+ const int output_shift = params.output_shift;
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+
+ for (int b = 0; b < batches; ++b) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int ic = 0; ic < input_depth; ++ic) {
+ for (int m = 0; m < depth_multiplier; m++) {
+ const int oc = m + ic * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x =
+ in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
+ // If the location is outside the bounds of the input image,
+ // use zero as a default value.
+ if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height)) {
+ int32 input_val =
+ input_data[Offset(input_shape, b, in_y, in_x, ic)];
+ int32 filter_val = filter_data[Offset(
+ filter_shape, 0, filter_y, filter_x, oc)];
+ acc += (filter_val + filter_offset) *
+ (input_val + input_offset);
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[oc];
+ }
+ acc = DepthwiseConvRound<output_rounding>(acc, output_multiplier,
+ output_shift);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_shape, b, out_y, out_x, oc)] =
+ static_cast<uint8>(acc);
+ }
+ }
+ }
+ }
+ }
+ }
+};
+
inline void DepthwiseConv(
const DepthwiseParams& params, const RuntimeShape& input_shape,
const uint8* input_data, const RuntimeShape& filter_shape,
const uint8* filter_data, const RuntimeShape& bias_shape,
const int32* bias_data, const RuntimeShape& output_shape,
uint8* output_data) {
- const int stride_width = params.stride_width;
- const int stride_height = params.stride_height;
- const int dilation_width_factor = params.dilation_width_factor;
- const int dilation_height_factor = params.dilation_height_factor;
- const int pad_width = params.padding_values.width;
- const int pad_height = params.padding_values.height;
- const int depth_multiplier = params.depth_multiplier;
- const int32 output_activation_min = params.quantized_activation_min;
- const int32 output_activation_max = params.quantized_activation_max;
- const int32 input_offset = params.input_offset;
- const int32 filter_offset = params.weights_offset;
- const int32 output_offset = params.output_offset;
- const int32 output_multiplier = params.output_multiplier;
- const int output_shift = params.output_shift;
- TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
- TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
-
- TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
- const int batches = MatchingDim(input_shape, 0, output_shape, 0);
- const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
- const int input_height = input_shape.Dims(1);
- const int input_width = input_shape.Dims(2);
- const int input_depth = input_shape.Dims(3);
- const int filter_height = filter_shape.Dims(1);
- const int filter_width = filter_shape.Dims(2);
- const int output_height = output_shape.Dims(1);
- const int output_width = output_shape.Dims(2);
- TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
- TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
-
- for (int b = 0; b < batches; ++b) {
- for (int out_y = 0; out_y < output_height; ++out_y) {
- for (int out_x = 0; out_x < output_width; ++out_x) {
- for (int ic = 0; ic < input_depth; ++ic) {
- for (int m = 0; m < depth_multiplier; m++) {
- const int oc = m + ic * depth_multiplier;
- const int in_x_origin = (out_x * stride_width) - pad_width;
- const int in_y_origin = (out_y * stride_height) - pad_height;
- int32 acc = 0;
- for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
- for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
- const int in_x = in_x_origin + dilation_width_factor * filter_x;
- const int in_y =
- in_y_origin + dilation_height_factor * filter_y;
- // If the location is outside the bounds of the input image,
- // use zero as a default value.
- if ((in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
- (in_y < input_height)) {
- int32 input_val =
- input_data[Offset(input_shape, b, in_y, in_x, ic)];
- int32 filter_val = filter_data[Offset(
- filter_shape, 0, filter_y, filter_x, oc)];
- acc +=
- (filter_val + filter_offset) * (input_val + input_offset);
- }
- }
- }
- if (bias_data) {
- acc += bias_data[oc];
- }
- acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
- output_shift);
- acc += output_offset;
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[Offset(output_shape, b, out_y, out_x, oc)] =
- static_cast<uint8>(acc);
- }
- }
- }
- }
- }
+ return DepthwiseConvBasicKernel<
+ DepthwiseConvOutputRounding::kAwayFromZero>::Run(params, input_shape,
+ input_data, filter_shape,
+ filter_data, bias_shape,
+ bias_data, output_shape,
+ output_data);
}
-} // end namespace reference_ops
+} // namespace reference_ops
} // end namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_
diff --git a/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h b/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h
new file mode 100644
index 0000000..90a7d61
--- /dev/null
+++ b/tensorflow/lite/kernels/internal/reference/integer_ops/depthwise_conv.h
@@ -0,0 +1,125 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
+#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
+
+#include "tensorflow/lite/kernels/internal/common.h"
+
+namespace tflite {
+namespace reference_integer_ops {
+inline void DepthwiseConvPerChannel(
+ const DepthwiseParams& params, const int32* output_multiplier,
+ const int32* output_shift, const RuntimeShape& input_shape,
+ const int8* input_data, const RuntimeShape& filter_shape,
+ const int8* filter_data, const RuntimeShape& bias_shape,
+ const int32* bias_data, const RuntimeShape& output_shape,
+ int8* output_data) {
+ // Get parameters.
+ const int stride_width = params.stride_width;
+ const int stride_height = params.stride_height;
+ const int dilation_width_factor = params.dilation_width_factor;
+ const int dilation_height_factor = params.dilation_height_factor;
+ const int pad_width = params.padding_values.width;
+ const int pad_height = params.padding_values.height;
+ const int depth_multiplier = params.depth_multiplier;
+ const int32 input_offset = params.input_offset;
+ const int32 output_offset = params.output_offset;
+
+ // Set min and max value of the output.
+ const int32 output_activation_min = std::numeric_limits<int8_t>::min();
+ const int32 output_activation_max = std::numeric_limits<int8_t>::max();
+
+ // Check dimensions of the tensors.
+ TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
+ TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
+
+ TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
+ const int batches = MatchingDim(input_shape, 0, output_shape, 0);
+ const int output_depth = MatchingDim(filter_shape, 3, output_shape, 3);
+ const int input_height = input_shape.Dims(1);
+ const int input_width = input_shape.Dims(2);
+ const int input_depth = input_shape.Dims(3);
+ const int filter_height = filter_shape.Dims(1);
+ const int filter_width = filter_shape.Dims(2);
+ const int output_height = output_shape.Dims(1);
+ const int output_width = output_shape.Dims(2);
+ TFLITE_DCHECK_EQ(output_depth, input_depth * depth_multiplier);
+ TFLITE_DCHECK_EQ(bias_shape.FlatSize(), output_depth);
+
+ for (int batch = 0; batch < batches; ++batch) {
+ for (int out_y = 0; out_y < output_height; ++out_y) {
+ for (int out_x = 0; out_x < output_width; ++out_x) {
+ for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+ for (int m = 0; m < depth_multiplier; ++m) {
+ const int output_channel = m + in_channel * depth_multiplier;
+ const int in_x_origin = (out_x * stride_width) - pad_width;
+ const int in_y_origin = (out_y * stride_height) - pad_height;
+ int32 acc = 0;
+ for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+ for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+ const int in_x = in_x_origin + dilation_width_factor * filter_x;
+ const int in_y =
+ in_y_origin + dilation_height_factor * filter_y;
+ // Zero padding by omitting the areas outside the image.
+ const bool is_point_inside_image =
+ (in_x >= 0) && (in_x < input_width) && (in_y >= 0) &&
+ (in_y < input_height);
+ if (is_point_inside_image) {
+ int32 input_val = input_data[Offset(input_shape, batch, in_y,
+ in_x, in_channel)];
+ int32 filter_val = filter_data[Offset(
+ filter_shape, 0, filter_y, filter_x, output_channel)];
+ // Accumulate with 32 bits accumulator.
+ // In the nudging process during model quantization, we force
+ // real value of 0.0 be represented by a quantized value. This
+ // guarentees that the input_offset is a int8, even though it
+ // is represented using int32.
+ // int32 += int8 * (int8 - int8) so the highest value we can
+ // get from each accumulation is [-127, 127] * ([-128, 127] -
+ // [-128, 127]), which is [-32512, 32512]. log2(32512)
+ // = 14.98, which means we can accumulate at least 2^16
+ // multiplications without overflow. The accumulator is
+ // applied to a filter so the accumation logic will hold as
+ // long as the filter size (filter_y * filter_x * in_channel)
+ // does not exceed 2^16, which is the case in all the models
+ // we have seen so far.
+ // TODO(jianlijianli): Add a check to make sure the
+ // accumulator depth is smaller than 2^16.
+ acc += filter_val * (input_val - input_offset);
+ }
+ }
+ }
+ if (bias_data) {
+ acc += bias_data[output_channel];
+ }
+ acc = MultiplyByQuantizedMultiplier(
+ acc, output_multiplier[output_channel],
+ output_shift[output_channel]);
+ acc += output_offset;
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_data[Offset(output_shape, batch, out_y, out_x,
+ output_channel)] = static_cast<int8_t>(acc);
+ }
+ }
+ }
+ }
+ }
+}
+} // namespace reference_integer_ops
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_INTEGER_OPS_DEPTHWISE_CONV_H_
diff --git a/tensorflow/lite/kernels/internal/reference/reference_ops.h b/tensorflow/lite/kernels/internal/reference/reference_ops.h
index ac6905f..84f62b1 100644
--- a/tensorflow/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/lite/kernels/internal/reference/reference_ops.h
@@ -702,6 +702,22 @@
}
}
+// T is expected to be either float or int.
+template <typename T>
+inline void AddN(const RuntimeShape& input_shape, const size_t num_inputs,
+ T* const* input_data, T* output_data) {
+ // All inputs and output should have the same shape, this is checked during
+ // Prepare stage.
+ const size_t size = input_shape.FlatSize();
+ for (int i = 0; i < size; ++i) {
+ T x = 0;
+ for (int j = 0; j < num_inputs; ++j) {
+ x += input_data[j][i];
+ }
+ output_data[i] = x;
+ }
+}
+
// Element-wise add that can often be used for inner loop of broadcast add as
// well as the non-broadcast add.
inline void AddElementwise(int size, const ArithmeticParams& params,
diff --git a/tensorflow/lite/kernels/layer_norm_lstm.cc b/tensorflow/lite/kernels/layer_norm_lstm.cc
deleted file mode 100644
index ce0c21d..0000000
--- a/tensorflow/lite/kernels/layer_norm_lstm.cc
+++ /dev/null
@@ -1,1324 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// DEPRECATED: Tensorflow Lite has implemented layer norm lstm as builtin Op and
-// the implementation of layer norm lstm as custom Op in this file is
-// deprecated. It is only kept for backward compatibility.
-//
-// Layer Normalization LSTM op that applies normalization by mean and standard
-// deviation to the activation of the LSTM layers. Please see
-// https://arxiv.org/abs/1607.06450 for details.
-#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
-#include "tensorflow/lite/context.h"
-#include "tensorflow/lite/kernels/internal/tensor_utils.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-namespace tflite {
-namespace ops {
-namespace custom {
-namespace layer_norm_lstm {
-
-// Struct to hold Layer Norm LSTM option data.
-struct OpData {
- TfLiteFusedActivation activation;
- float cell_clip;
- float proj_clip;
- int scratch_tensor_index;
-};
-
-// Input Tensors of size {n_batch, n_input}
-constexpr int kInputTensor = 0;
-
-// Input weight tensors of size: {n_cell, n_input}
-constexpr int kInputToInputWeightsTensor = 1; // Optional
-constexpr int kInputToForgetWeightsTensor = 2;
-constexpr int kInputToCellWeightsTensor = 3;
-constexpr int kInputToOutputWeightsTensor = 4;
-
-// Recurrent weight tensors of size {n_cell, n_output}
-constexpr int kRecurrentToInputWeightsTensor = 5; // Optional
-constexpr int kRecurrentToForgetWeightsTensor = 6;
-constexpr int kRecurrentToCellWeightsTensor = 7;
-constexpr int kRecurrentToOutputWeightsTensor = 8;
-
-// Peephole weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kCellToInputWeightsTensor = 9; // Optional
-constexpr int kCellToForgetWeightsTensor = 10; // Optional
-constexpr int kCellToOutputWeightsTensor = 11; // Optional
-
-// Layer norm weights tensors of size {n_cell}, representing a diagonal matrix.
-constexpr int kInputLayerNormWeightsTensor = 12; // Optional
-constexpr int kForgetLayerNormWeightsTensor = 13;
-constexpr int kCellLayerNormWeightsTensor = 14;
-constexpr int kOutputLayerNormWeightsTensor = 15;
-
-// Gates bias tensors of size {n_cell}
-constexpr int kInputGateBiasTensor = 16; // Optional
-constexpr int kForgetGateBiasTensor = 17;
-constexpr int kCellGateBiasTensor = 18;
-constexpr int kOutputGateBiasTensor = 19;
-
-// Projection weight tensor of size {n_output, n_cell}
-constexpr int kProjectionWeightsTensor = 20; // Optional
-// Projection bias tensor of size {n_output}
-constexpr int kProjectionBiasTensor = 21; // Optional
-
-// State tensors.
-constexpr int kInputActivationStateTensor = 22;
-constexpr int kInputCellStateTensor = 23;
-
-// Output tensor.
-constexpr int kOutputTensor = 0;
-
-// Total number of scratch tensors for hybrid Op.
-constexpr int kTensorsToAdd = 7;
-
-// Small float to avoid divergence during calculation of deviation.
-const float kLayerNormEpsilon = 1e-8;
-
-void* Init(TfLiteContext* context, const char* buffer, size_t length) {
- auto* data = new OpData;
-
- // Turn custom option data into flexbuffer map format.
- const uint8_t* buffer_t = reinterpret_cast<const uint8_t*>(buffer);
- const flexbuffers::Map& m = flexbuffers::GetRoot(buffer_t, length).AsMap();
-
- // Get activation function, cell_clip and proj_clip from the flexbuffer.
- // TODO(b/113824099): make activation more generic.
- assert(m["fused_activation_function"].ToString() == "TANH");
- data->activation = kTfLiteActTanh;
- data->cell_clip = m["cell_clip"].AsFloat();
- data->proj_clip = m["proj_clip"].AsFloat();
-
- // Populate scratch_tensor_index.
- context->AddTensors(context, /*tensors_to_add=*/kTensorsToAdd,
- &data->scratch_tensor_index);
- return data;
-}
-
-// Check that input tensor dimensions matches with each other.
-TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
- TfLiteNode* node, int n_input,
- int n_output, int n_cell) {
- const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-
- // Making sure clipping parameters have valid values.
- // == 0 means no clipping
- // > 0 means clipping
- TF_LITE_ENSURE(context, op_data->cell_clip >= 0);
- TF_LITE_ENSURE(context, op_data->proj_clip >= 0);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const bool use_cifg = (input_to_input_weights == nullptr);
- if (!use_cifg) {
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
- }
-
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
-
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- if (use_cifg) {
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights, nullptr);
- } else {
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
- n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
- n_output);
- }
-
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
- n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
- n_output);
-
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
- TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
- n_output);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- if (cell_to_input_weights) {
- TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
- }
-
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- if (cell_to_forget_weights) {
- TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
- }
-
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
- if (cell_to_output_weights) {
- TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
- }
-
- // Making sure the peephole weights are there all or none.
- const bool peephole_weights_all_or_none =
- ((cell_to_input_weights != nullptr || use_cifg) &&
- (cell_to_forget_weights != nullptr) &&
- (cell_to_output_weights != nullptr)) ||
- ((cell_to_input_weights == nullptr) &&
- (cell_to_forget_weights == nullptr) &&
- (cell_to_output_weights == nullptr));
- TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
-
- // Making sure layer norm weights are not null and have the right dimension.
- const TfLiteTensor* input_layer_norm_weights =
- GetOptionalInputTensor(context, node, kInputLayerNormWeightsTensor);
- if (use_cifg) {
- TF_LITE_ENSURE_EQ(context, input_layer_norm_weights, nullptr);
- } else {
- TF_LITE_ENSURE(context, input_layer_norm_weights != nullptr);
- TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, input_layer_norm_weights->dims->data[0], n_cell);
- }
-
- const TfLiteTensor* forget_layer_norm_weights =
- GetInput(context, node, kForgetLayerNormWeightsTensor);
- TF_LITE_ENSURE(context, forget_layer_norm_weights != nullptr);
- TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, forget_layer_norm_weights->dims->data[0], n_cell);
-
- const TfLiteTensor* cell_layer_norm_weights =
- GetInput(context, node, kCellLayerNormWeightsTensor);
- TF_LITE_ENSURE(context, cell_layer_norm_weights != nullptr);
- TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_layer_norm_weights->dims->data[0], n_cell);
-
- const TfLiteTensor* output_layer_norm_weights =
- GetInput(context, node, kOutputLayerNormWeightsTensor);
- TF_LITE_ENSURE(context, output_layer_norm_weights != nullptr);
- TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, output_layer_norm_weights->dims->data[0], n_cell);
-
- // Make sure the input gate bias is present only when not a CIFG-LSTM.
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- if (use_cifg) {
- TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
- } else {
- TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
- }
-
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
-
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
-
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
- TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- if (projection_weights != nullptr) {
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
- TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
- }
-
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
- if (projection_bias != nullptr) {
- TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
- TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
- }
-
- // Making sure the projection tensors are consistent:
- // 1) If projection weight is not present, then projection bias should not be
- // present.
- // 2) If projection weight is present, then projection bias is optional.
- const bool projection_tensors_consistent =
- ((projection_weights != nullptr) || (projection_bias == nullptr));
- TF_LITE_ENSURE(context, projection_tensors_consistent == true);
-
- return kTfLiteOk;
-}
-
-// Resize the output, state tensors based on the sizes of the input tensors.
-// Allocate a temporary scratch tensor. Also check that the sizes of the input
-// tensors match each other.
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
- TF_LITE_ENSURE_EQ(context, node->inputs->size, 24);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
-
- // Inferring batch size, number of outputs and number of cells from the
- // input tensors.
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
- TF_LITE_ENSURE(context, input->dims->size > 1);
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
-
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
- const int n_cell = input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
-
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
- TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
- n_cell);
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Check that input tensor dimensions matches with each other.
- TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
- n_output, n_cell));
-
- // Get the pointer to output, activation_state and cell_state tensors.
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- const TfLiteTensor* activation_state =
- GetInput(context, node, kInputActivationStateTensor);
- const TfLiteTensor* cell_state =
- GetInput(context, node, kInputCellStateTensor);
-
- // Check the shape of input state tensors.
- // These tensor may be 1D or 2D. It's fine as long as the total size is
- // correct.
- TF_LITE_ENSURE_EQ(context, NumElements(activation_state), n_batch * n_output);
- TF_LITE_ENSURE_EQ(context, NumElements(cell_state), n_batch * n_cell);
- // Resize the output tensors.
- TfLiteIntArray* output_size = TfLiteIntArrayCreate(2);
- output_size->data[0] = n_batch;
- output_size->data[1] = n_output;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, output, output_size));
-
- // The weights are of consistent type, so it suffices to check one.
- const bool is_hybrid_op = (input_to_output_weights->type == kTfLiteUInt8 &&
- input->type == kTfLiteFloat32);
-
- TfLiteIntArrayFree(node->temporaries);
- if (is_hybrid_op) {
- node->temporaries = TfLiteIntArrayCreate(7);
- } else {
- node->temporaries = TfLiteIntArrayCreate(1);
- }
- node->temporaries->data[0] = op_data->scratch_tensor_index;
-
- // Create a scratch buffer tensor.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
- scratch_buffer->type = input->type;
- scratch_buffer->allocation_type = kTfLiteArenaRw;
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const bool use_cifg = (input_to_input_weights == nullptr);
- TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
- scratch_buffer_size->data[0] = n_batch;
- if (use_cifg) {
- // Reserving space for Cell, Forget, Output gates
- scratch_buffer_size->data[1] = n_cell * 3;
- } else {
- // Reserving space for Input, Cell, Forget, Output gates
- scratch_buffer_size->data[1] = n_cell * 4;
- }
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
- scratch_buffer_size));
-
- if (is_hybrid_op) {
- // Allocate temporary tensors to store quantized values of input,
- // activation_state and cell_state tensors.
- node->temporaries->data[1] = op_data->scratch_tensor_index + 1;
- TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- input_quantized->type = kTfLiteUInt8;
- input_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(input_quantized->dims, input->dims)) {
- TfLiteIntArray* input_quantized_size = TfLiteIntArrayCopy(input->dims);
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, input_quantized,
- input_quantized_size));
- }
- node->temporaries->data[2] = op_data->scratch_tensor_index + 2;
- TfLiteTensor* activation_state_quantized =
- GetTemporary(context, node, /*index=*/2);
- activation_state_quantized->type = kTfLiteUInt8;
- activation_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(activation_state_quantized->dims,
- activation_state->dims)) {
- TfLiteIntArray* activation_state_quantized_size =
- TfLiteIntArrayCopy(activation_state->dims);
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, activation_state_quantized,
- activation_state_quantized_size));
- }
- node->temporaries->data[3] = op_data->scratch_tensor_index + 3;
- TfLiteTensor* cell_state_quantized =
- GetTemporary(context, node, /*index=*/3);
- cell_state_quantized->type = kTfLiteUInt8;
- cell_state_quantized->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqual(cell_state_quantized->dims, cell_state->dims)) {
- TfLiteIntArray* cell_state_quantized_size =
- TfLiteIntArrayCopy(cell_state->dims);
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, cell_state_quantized,
- cell_state_quantized_size));
- }
-
- // Allocate temporary tensors to store scaling factors and product scaling
- // factors. The latter is a convenience storage which allows to quantize
- // a vector once (which produces the scaling factors) and multiply it with
- // different matrices (which requires multiplying the scaling factors with
- // the scaling factor of the matrix).
- node->temporaries->data[4] = op_data->scratch_tensor_index + 4;
- TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
- scaling_factors->type = kTfLiteFloat32;
- scaling_factors->allocation_type = kTfLiteArenaRw;
- int scaling_dims[1] = {n_batch};
- if (!TfLiteIntArrayEqualsArray(scaling_factors->dims, 1, scaling_dims)) {
- TfLiteIntArray* scaling_factors_size = TfLiteIntArrayCreate(1);
- scaling_factors_size->data[0] = n_batch;
- TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scaling_factors,
- scaling_factors_size));
- }
- node->temporaries->data[5] = op_data->scratch_tensor_index + 5;
- TfLiteTensor* prod_scaling_factors =
- GetTemporary(context, node, /*index=*/5);
- prod_scaling_factors->type = kTfLiteFloat32;
- prod_scaling_factors->allocation_type = kTfLiteArenaRw;
- if (!TfLiteIntArrayEqualsArray(prod_scaling_factors->dims, 1,
- scaling_dims)) {
- TfLiteIntArray* prod_scaling_factors_size = TfLiteIntArrayCreate(1);
- prod_scaling_factors_size->data[0] = n_batch;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, prod_scaling_factors,
- prod_scaling_factors_size));
- }
-
- // Allocate a temporary tensor to store the recovered weights. Since
- // this is used for diagonal matrices, only need to store n_cell values.
- node->temporaries->data[6] = op_data->scratch_tensor_index + 6;
- TfLiteTensor* recovered_weights = GetTemporary(context, node, /*index=*/6);
- recovered_weights->type = kTfLiteFloat32;
- recovered_weights->allocation_type = kTfLiteArenaRw;
- int recovered_dims[1] = {n_cell};
- if (!TfLiteIntArrayEqualsArray(recovered_weights->dims, 1,
- recovered_dims)) {
- TfLiteIntArray* recovered_weights_size = TfLiteIntArrayCreate(1);
- recovered_weights_size->data[0] = n_cell;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, recovered_weights,
- recovered_weights_size));
- }
- }
- return kTfLiteOk;
-}
-
-void LayerNormLstmStep(
- const float* input_ptr_batch, const float* input_to_input_weights_ptr,
- const float* input_to_forget_weights_ptr,
- const float* input_to_cell_weights_ptr,
- const float* input_to_output_weights_ptr,
- const float* recurrent_to_input_weights_ptr,
- const float* recurrent_to_forget_weights_ptr,
- const float* recurrent_to_cell_weights_ptr,
- const float* recurrent_to_output_weights_ptr,
- const float* cell_to_input_weights_ptr,
- const float* cell_to_forget_weights_ptr,
- const float* cell_to_output_weights_ptr,
- const float* input_layer_norm_weight_ptr,
- const float* forget_layer_norm_weight_ptr,
- const float* cell_layer_norm_weight_ptr,
- const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const float* projection_weights_ptr,
- const float* projection_bias_ptr, float cell_clip, float proj_clip,
- const TfLiteFusedActivation& activation, int n_batch, int n_cell,
- int n_input, int n_output, float* output_state_ptr, float* cell_state_ptr,
- float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-
- // Initialize scratch buffers with 0.
- if (!use_cifg) {
- tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
- }
- tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
- tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
- tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
-
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, input_ptr_batch, n_batch,
- output_gate_scratch, /*result_stride=*/1);
-
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, input_gate_scratch, /*result_stride=*/1);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, forget_gate_scratch,
- /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, cell_scratch, /*result_stride=*/1);
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output, output_state_ptr,
- n_batch, output_gate_scratch,
- /*result_stride=*/1);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_input_weights_ptr, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(input_gate_scratch,
- input_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
- n_cell, input_gate_scratch,
- n_batch, input_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_forget_weights_ptr, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(forget_gate_scratch,
- forget_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
- n_cell, forget_gate_scratch,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
- n_batch, kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(
- cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
- tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
- cell_state_ptr);
- }
-
- // For each batch and cell: update the output gate.
- if (use_peephole) {
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- cell_to_output_weights_ptr, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(output_gate_scratch,
- output_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
- n_cell, output_gate_scratch,
- n_batch, output_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, output_gate_scratch, n_batch,
- output_ptr_batch, /*result_stride=*/1);
- if (proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
- output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
-void LayerNormLstmStep(
- const float* input_ptr_batch, const int8_t* input_to_input_weights_ptr,
- float input_to_input_weights_scale,
- const int8_t* input_to_forget_weights_ptr,
- float input_to_forget_weights_scale,
- const int8_t* input_to_cell_weights_ptr, float input_to_cell_weights_scale,
- const int8_t* input_to_output_weights_ptr,
- float input_to_output_weights_scale,
- const int8_t* recurrent_to_input_weights_ptr,
- float recurrent_to_input_weights_scale,
- const int8_t* recurrent_to_forget_weights_ptr,
- float recurrent_to_forget_weights_scale,
- const int8_t* recurrent_to_cell_weights_ptr,
- float recurrent_to_cell_weights_scale,
- const int8_t* recurrent_to_output_weights_ptr,
- float recurrent_to_output_weights_scale,
- const int8_t* cell_to_input_weights_ptr, float cell_to_input_weights_scale,
- const int8_t* cell_to_forget_weights_ptr,
- float cell_to_forget_weights_scale,
- const int8_t* cell_to_output_weights_ptr,
- float cell_to_output_weights_scale,
- const float* input_layer_norm_weight_ptr,
- const float* forget_layer_norm_weight_ptr,
- const float* cell_layer_norm_weight_ptr,
- const float* output_layer_norm_weight_ptr, const float* input_gate_bias_ptr,
- const float* forget_gate_bias_ptr, const float* cell_bias_ptr,
- const float* output_gate_bias_ptr, const int8_t* projection_weights_ptr,
- float projection_weights_scale, const float* projection_bias_ptr,
- float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
- int n_batch, int n_cell, int n_input, int n_output,
- float* input_gate_scratch, float* forget_gate_scratch, float* cell_scratch,
- float* output_gate_scratch, float* scaling_factors,
- float* product_scaling_factors, float* recovered_weights,
- int8_t* quantized_input_ptr_batch, int8_t* quantized_output_state_ptr,
- int8_t* quantized_cell_state_ptr, float* output_state_ptr,
- float* cell_state_ptr, float* output_ptr_batch) {
- // Since we have already checked that weights are all there or none, we can
- // check the existense of only one to the get the condition.
- const bool use_cifg = (input_to_input_weights_ptr == nullptr);
- const bool use_peephole = (cell_to_output_weights_ptr != nullptr);
-
- // Initialize scratch buffers with 0.
- if (!use_cifg) {
- tensor_utils::ZeroVector(input_gate_scratch, n_cell * n_batch);
- }
- tensor_utils::ZeroVector(forget_gate_scratch, n_cell * n_batch);
- tensor_utils::ZeroVector(cell_scratch, n_cell * n_batch);
- tensor_utils::ZeroVector(output_gate_scratch, n_cell * n_batch);
-
- if (!tensor_utils::IsZeroVector(input_ptr_batch, n_batch * n_input)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_input;
- tensor_utils::SymmetricQuantizeFloats(
- input_ptr_batch + offset, n_input, quantized_input_ptr_batch + offset,
- &unused_min, &unused_max, &scaling_factors[b]);
- }
- // For each batch and cell: compute input_weight * input.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_input_weights_ptr, n_cell, n_input,
- quantized_input_ptr_batch, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_forget_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, forget_gate_scratch,
- /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_cell_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * input_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- input_to_output_weights_ptr, n_cell, n_input, quantized_input_ptr_batch,
- product_scaling_factors, n_batch, output_gate_scratch,
- /*result_stride=*/1);
- }
-
- if (!tensor_utils::IsZeroVector(output_state_ptr, n_batch * n_output)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_output;
- tensor_utils::SymmetricQuantizeFloats(output_state_ptr + offset, n_output,
- quantized_output_state_ptr + offset,
- &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- // For each batch and cell: compute recurrent_weight * output_state.
- if (!use_cifg) {
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_input_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_input_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- input_gate_scratch, /*result_stride=*/1);
- }
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_forget_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_forget_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- forget_gate_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_cell_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_cell_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- cell_scratch, /*result_stride=*/1);
-
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * recurrent_to_output_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- recurrent_to_output_weights_ptr, n_cell, n_output,
- quantized_output_state_ptr, product_scaling_factors, n_batch,
- output_gate_scratch, /*result_stride=*/1);
- }
-
- // Save quantization and matmul computation for all zero input.
- bool is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
-
- // For each batch and cell: update input gate.
- if (!use_cifg) {
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_input_weights_ptr, n_cell,
- cell_to_input_weights_scale,
- recovered_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_weights, n_cell, cell_state_ptr, n_batch,
- input_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(input_gate_scratch,
- input_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(input_layer_norm_weight_ptr,
- n_cell, input_gate_scratch,
- n_batch, input_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(input_gate_bias_ptr, n_cell, n_batch,
- input_gate_scratch);
- tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
- input_gate_scratch);
- }
-
- // For each batch and cell: update forget gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_forget_weights_ptr, n_cell,
- cell_to_forget_weights_scale,
- recovered_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_weights, n_cell, cell_state_ptr, n_batch,
- forget_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(forget_gate_scratch,
- forget_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(forget_layer_norm_weight_ptr,
- n_cell, forget_gate_scratch,
- n_batch, forget_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(forget_gate_bias_ptr, n_cell, n_batch,
- forget_gate_scratch);
- tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
- forget_gate_scratch);
-
- // For each batch and cell: update the cell.
- tensor_utils::MeanStddevNormalization(cell_scratch, cell_scratch, n_cell,
- n_batch, kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(
- cell_layer_norm_weight_ptr, n_cell, cell_scratch, n_batch, cell_scratch);
- tensor_utils::VectorBatchVectorAdd(cell_bias_ptr, n_cell, n_batch,
- cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch, cell_state_ptr,
- n_batch * n_cell, cell_state_ptr);
- tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
- activation, cell_scratch);
- if (use_cifg) {
- tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
- forget_gate_scratch);
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, forget_gate_scratch, n_batch * n_cell, cell_state_ptr);
- } else {
- tensor_utils::VectorVectorCwiseProductAccumulate(
- cell_scratch, input_gate_scratch, n_batch * n_cell, cell_state_ptr);
- }
- if (cell_clip > 0.0) {
- tensor_utils::ClipVector(cell_state_ptr, n_batch * n_cell, cell_clip,
- cell_state_ptr);
- }
-
- is_cell_state_all_zeros =
- tensor_utils::IsZeroVector(cell_state_ptr, n_batch * n_cell);
- // For each batch and cell: update the output gate.
- if (use_peephole && !is_cell_state_all_zeros) {
- tensor_utils::VectorScalarMultiply(cell_to_output_weights_ptr, n_cell,
- cell_to_output_weights_scale,
- recovered_weights);
- tensor_utils::VectorBatchVectorCwiseProductAccumulate(
- recovered_weights, n_cell, cell_state_ptr, n_batch,
- output_gate_scratch);
- }
- tensor_utils::MeanStddevNormalization(output_gate_scratch,
- output_gate_scratch, n_cell, n_batch,
- kLayerNormEpsilon);
- tensor_utils::VectorBatchVectorCwiseProduct(output_layer_norm_weight_ptr,
- n_cell, output_gate_scratch,
- n_batch, output_gate_scratch);
- tensor_utils::VectorBatchVectorAdd(output_gate_bias_ptr, n_cell, n_batch,
- output_gate_scratch);
- tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
- output_gate_scratch);
- tensor_utils::ApplyActivationToVector(cell_state_ptr, n_batch * n_cell,
- activation, cell_scratch);
- tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
- n_batch * n_cell, output_gate_scratch);
-
- // For each batch: update the projection and output_state.
- const bool use_projection_weight = (projection_weights_ptr != nullptr);
- const bool use_projection_bias = (projection_bias_ptr != nullptr);
- if (use_projection_weight) {
- if (use_projection_bias) {
- tensor_utils::VectorBatchVectorAssign(projection_bias_ptr, n_output,
- n_batch, output_ptr_batch);
- } else {
- tensor_utils::ZeroVector(output_ptr_batch, n_batch * n_output);
- }
- if (!tensor_utils::IsZeroVector(output_gate_scratch, n_batch * n_cell)) {
- // Save quantization and matmul computation for all zero input.
- float unused_min, unused_max;
- for (int b = 0; b < n_batch; ++b) {
- const int offset = b * n_cell;
- tensor_utils::SymmetricQuantizeFloats(
- output_gate_scratch + offset, n_cell,
- quantized_cell_state_ptr + offset, &unused_min, &unused_max,
- &scaling_factors[b]);
- }
- for (int b = 0; b < n_batch; ++b) {
- product_scaling_factors[b] =
- scaling_factors[b] * projection_weights_scale;
- }
- tensor_utils::MatrixBatchVectorMultiplyAccumulate(
- projection_weights_ptr, n_output, n_cell, quantized_cell_state_ptr,
- product_scaling_factors, n_batch, output_ptr_batch,
- /*result_stride=*/1);
- }
- if (proj_clip > 0.0) {
- tensor_utils::ClipVector(output_ptr_batch, n_batch * n_output, proj_clip,
- output_ptr_batch);
- }
- } else {
- tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
- output_ptr_batch);
- }
- tensor_utils::CopyVector(output_ptr_batch, n_batch * n_output,
- output_state_ptr);
-}
-
-// The LayerNormLSTM Op engine.
-TfLiteStatus EvalFloat(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_layer_norm_weights,
- const TfLiteTensor* forget_layer_norm_weights,
- const TfLiteTensor* cell_layer_norm_weights,
- const TfLiteTensor* output_layer_norm_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
- TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- const float* input_to_input_weights_ptr =
- (use_cifg) ? nullptr : input_to_input_weights->data.f;
- const float* recurrent_to_input_weights_ptr =
- (use_cifg) ? nullptr : recurrent_to_input_weights->data.f;
- const float* input_gate_bias_ptr =
- (use_cifg) ? nullptr : input_gate_bias->data.f;
- const float* cell_to_input_weights_ptr =
- (use_peephole && !use_cifg) ? cell_to_input_weights->data.f : nullptr;
- const float* cell_to_forget_weights_ptr =
- (use_peephole) ? cell_to_forget_weights->data.f : nullptr;
- const float* cell_to_output_weights_ptr =
- (use_peephole) ? cell_to_output_weights->data.f : nullptr;
- const float* projection_weights_ptr =
- (projection_weights == nullptr) ? nullptr : projection_weights->data.f;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
- const float* input_layer_norm_weight_ptr =
- (input_layer_norm_weights == nullptr) ? nullptr
- : input_layer_norm_weights->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const float* input_to_forget_weights_ptr = input_to_forget_weights->data.f;
- const float* input_to_cell_weights_ptr = input_to_cell_weights->data.f;
- const float* input_to_output_weights_ptr = input_to_output_weights->data.f;
- const float* recurrent_to_forget_weights_ptr =
- recurrent_to_forget_weights->data.f;
- const float* recurrent_to_cell_weights_ptr =
- recurrent_to_cell_weights->data.f;
- const float* recurrent_to_output_weights_ptr =
- recurrent_to_output_weights->data.f;
- const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
- const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
- const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- LayerNormLstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_forget_weights_ptr,
- input_to_cell_weights_ptr, input_to_output_weights_ptr,
- recurrent_to_input_weights_ptr, recurrent_to_forget_weights_ptr,
- recurrent_to_cell_weights_ptr, recurrent_to_output_weights_ptr,
- cell_to_input_weights_ptr, cell_to_forget_weights_ptr,
- cell_to_output_weights_ptr, input_layer_norm_weight_ptr,
- forget_layer_norm_weight_ptr, cell_layer_norm_weight_ptr,
- output_layer_norm_weight_ptr, input_gate_bias_ptr, forget_gate_bias_ptr,
- cell_bias_ptr, output_gate_bias_ptr, projection_weights_ptr,
- projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
- n_input, n_output, activation_state_ptr, cell_state_ptr,
- input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus EvalHybrid(
- const TfLiteTensor* input, const TfLiteTensor* input_to_input_weights,
- const TfLiteTensor* input_to_forget_weights,
- const TfLiteTensor* input_to_cell_weights,
- const TfLiteTensor* input_to_output_weights,
- const TfLiteTensor* recurrent_to_input_weights,
- const TfLiteTensor* recurrent_to_forget_weights,
- const TfLiteTensor* recurrent_to_cell_weights,
- const TfLiteTensor* recurrent_to_output_weights,
- const TfLiteTensor* cell_to_input_weights,
- const TfLiteTensor* cell_to_forget_weights,
- const TfLiteTensor* cell_to_output_weights,
- const TfLiteTensor* input_layer_norm_weights,
- const TfLiteTensor* forget_layer_norm_weights,
- const TfLiteTensor* cell_layer_norm_weights,
- const TfLiteTensor* output_layer_norm_weights,
- const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
- const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
- const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- float cell_clip, float proj_clip, const TfLiteFusedActivation& activation,
- TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
- TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_weights,
- TfLiteTensor* input_quantized, TfLiteTensor* activation_state_quantized,
- TfLiteTensor* cell_state_quantized, TfLiteTensor* activation_state,
- TfLiteTensor* cell_state, TfLiteTensor* output) {
- const int n_batch = input->dims->data[0];
- const int n_input = input->dims->data[1];
- // n_cell and n_output will be the same size when there is no projection.
- const int n_cell = input_to_output_weights->dims->data[0];
- const int n_output = recurrent_to_output_weights->dims->data[1];
-
- // Since we have already checked that weights are all there or none, we can
- // check the existence of only one to get the condition.
- const bool use_cifg = (input_to_input_weights == nullptr);
- const bool use_peephole = (cell_to_output_weights != nullptr);
-
- float* input_gate_scratch = nullptr;
- float* cell_scratch = nullptr;
- float* forget_gate_scratch = nullptr;
- float* output_gate_scratch = nullptr;
- if (use_cifg) {
- cell_scratch = scratch_buffer->data.f;
- forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- } else {
- input_gate_scratch = scratch_buffer->data.f;
- cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
- forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
- output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
- }
-
- // Check optional tensors, the respective pointers can be null.
- int8_t* input_to_input_weights_ptr = nullptr;
- float input_to_input_weights_scale = 1.0f;
- int8_t* recurrent_to_input_weights_ptr = nullptr;
- float recurrent_to_input_weights_scale = 1.0f;
- float* input_gate_bias_ptr = nullptr;
- if (!use_cifg) {
- input_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_input_weights->data.uint8);
- recurrent_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_input_weights->data.uint8);
- input_gate_bias_ptr = input_gate_bias->data.f;
- input_to_input_weights_scale = input_to_input_weights->params.scale;
- recurrent_to_input_weights_scale = recurrent_to_input_weights->params.scale;
- }
-
- int8_t* cell_to_input_weights_ptr = nullptr;
- int8_t* cell_to_forget_weights_ptr = nullptr;
- int8_t* cell_to_output_weights_ptr = nullptr;
- float cell_to_input_weights_scale = 1.0f;
- float cell_to_forget_weights_scale = 1.0f;
- float cell_to_output_weights_scale = 1.0f;
- if (use_peephole) {
- if (!use_cifg) {
- cell_to_input_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_input_weights->data.uint8);
- cell_to_input_weights_scale = cell_to_input_weights->params.scale;
- }
- cell_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_forget_weights->data.uint8);
- cell_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(cell_to_output_weights->data.uint8);
- cell_to_forget_weights_scale = cell_to_forget_weights->params.scale;
- cell_to_output_weights_scale = cell_to_output_weights->params.scale;
- }
-
- const int8_t* projection_weights_ptr =
- (projection_weights == nullptr)
- ? nullptr
- : reinterpret_cast<int8_t*>(projection_weights->data.uint8);
- const float projection_weights_scale =
- (projection_weights == nullptr) ? 1.0f : projection_weights->params.scale;
- const float* projection_bias_ptr =
- (projection_bias == nullptr) ? nullptr : projection_bias->data.f;
- const float* input_layer_norm_weight_ptr =
- (input_layer_norm_weights == nullptr) ? nullptr
- : input_layer_norm_weights->data.f;
-
- // Required tensors, pointers are non-null.
- const float* input_ptr_batch = input->data.f;
- const int8_t* input_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_forget_weights->data.uint8);
- const float input_to_forget_weights_scale =
- input_to_forget_weights->params.scale;
- const int8_t* input_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_cell_weights->data.uint8);
- const float input_to_cell_weights_scale = input_to_cell_weights->params.scale;
- const int8_t* input_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(input_to_output_weights->data.uint8);
- const float input_to_output_weights_scale =
- input_to_output_weights->params.scale;
- const int8_t* recurrent_to_forget_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_forget_weights->data.uint8);
- const float recurrent_to_forget_weights_scale =
- recurrent_to_forget_weights->params.scale;
- const int8_t* recurrent_to_cell_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_cell_weights->data.uint8);
- const float recurrent_to_cell_weights_scale =
- recurrent_to_cell_weights->params.scale;
- const int8_t* recurrent_to_output_weights_ptr =
- reinterpret_cast<int8_t*>(recurrent_to_output_weights->data.uint8);
- const float recurrent_to_output_weights_scale =
- recurrent_to_output_weights->params.scale;
- const float* forget_layer_norm_weight_ptr = forget_layer_norm_weights->data.f;
- const float* cell_layer_norm_weight_ptr = cell_layer_norm_weights->data.f;
- const float* output_layer_norm_weight_ptr = output_layer_norm_weights->data.f;
- const float* forget_gate_bias_ptr = forget_gate_bias->data.f;
- const float* cell_bias_ptr = cell_bias->data.f;
- const float* output_gate_bias_ptr = output_gate_bias->data.f;
-
- float* activation_state_ptr = activation_state->data.f;
- float* cell_state_ptr = cell_state->data.f;
- float* output_ptr_batch = output->data.f;
-
- // Temporary storage for quantized values and scaling factors.
- int8_t* quantized_input_ptr =
- reinterpret_cast<int8_t*>(input_quantized->data.uint8);
- int8_t* quantized_activation_state_ptr =
- reinterpret_cast<int8_t*>(activation_state_quantized->data.uint8);
- int8_t* quantized_cell_state_ptr =
- reinterpret_cast<int8_t*>(cell_state_quantized->data.uint8);
- float* scaling_factors_ptr = scaling_factors->data.f;
- float* prod_scaling_factors_ptr = prod_scaling_factors->data.f;
- float* recovered_weights_ptr = recovered_weights->data.f;
-
- LayerNormLstmStep(
- input_ptr_batch, input_to_input_weights_ptr, input_to_input_weights_scale,
- input_to_forget_weights_ptr, input_to_forget_weights_scale,
- input_to_cell_weights_ptr, input_to_cell_weights_scale,
- input_to_output_weights_ptr, input_to_output_weights_scale,
- recurrent_to_input_weights_ptr, recurrent_to_input_weights_scale,
- recurrent_to_forget_weights_ptr, recurrent_to_forget_weights_scale,
- recurrent_to_cell_weights_ptr, recurrent_to_cell_weights_scale,
- recurrent_to_output_weights_ptr, recurrent_to_output_weights_scale,
- cell_to_input_weights_ptr, cell_to_input_weights_scale,
- cell_to_forget_weights_ptr, cell_to_forget_weights_scale,
- cell_to_output_weights_ptr, cell_to_output_weights_scale,
- input_layer_norm_weight_ptr, forget_layer_norm_weight_ptr,
- cell_layer_norm_weight_ptr, output_layer_norm_weight_ptr,
- input_gate_bias_ptr, forget_gate_bias_ptr, cell_bias_ptr,
- output_gate_bias_ptr, projection_weights_ptr, projection_weights_scale,
- projection_bias_ptr, cell_clip, proj_clip, activation, n_batch, n_cell,
- n_input, n_output, input_gate_scratch, forget_gate_scratch, cell_scratch,
- output_gate_scratch, scaling_factors_ptr, prod_scaling_factors_ptr,
- recovered_weights_ptr, quantized_input_ptr,
- quantized_activation_state_ptr, quantized_cell_state_ptr,
- activation_state_ptr, cell_state_ptr, output_ptr_batch);
-
- return kTfLiteOk;
-}
-
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
-
- const TfLiteTensor* input = GetInput(context, node, kInputTensor);
-
- const TfLiteTensor* input_to_input_weights =
- GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
- const TfLiteTensor* input_to_forget_weights =
- GetInput(context, node, kInputToForgetWeightsTensor);
- const TfLiteTensor* input_to_cell_weights =
- GetInput(context, node, kInputToCellWeightsTensor);
- const TfLiteTensor* input_to_output_weights =
- GetInput(context, node, kInputToOutputWeightsTensor);
-
- const TfLiteTensor* recurrent_to_input_weights =
- GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
- const TfLiteTensor* recurrent_to_forget_weights =
- GetInput(context, node, kRecurrentToForgetWeightsTensor);
- const TfLiteTensor* recurrent_to_cell_weights =
- GetInput(context, node, kRecurrentToCellWeightsTensor);
- const TfLiteTensor* recurrent_to_output_weights =
- GetInput(context, node, kRecurrentToOutputWeightsTensor);
-
- const TfLiteTensor* cell_to_input_weights =
- GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
- const TfLiteTensor* cell_to_forget_weights =
- GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
- const TfLiteTensor* cell_to_output_weights =
- GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
-
- const TfLiteTensor* input_layer_norm_weights =
- GetOptionalInputTensor(context, node, kInputLayerNormWeightsTensor);
- const TfLiteTensor* forget_layer_norm_weights =
- GetInput(context, node, kForgetLayerNormWeightsTensor);
- const TfLiteTensor* cell_layer_norm_weights =
- GetInput(context, node, kCellLayerNormWeightsTensor);
- const TfLiteTensor* output_layer_norm_weights =
- GetInput(context, node, kOutputLayerNormWeightsTensor);
-
- const TfLiteTensor* input_gate_bias =
- GetOptionalInputTensor(context, node, kInputGateBiasTensor);
- const TfLiteTensor* forget_gate_bias =
- GetInput(context, node, kForgetGateBiasTensor);
- const TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
- const TfLiteTensor* output_gate_bias =
- GetInput(context, node, kOutputGateBiasTensor);
-
- const TfLiteTensor* projection_weights =
- GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
- const TfLiteTensor* projection_bias =
- GetOptionalInputTensor(context, node, kProjectionBiasTensor);
-
- // Index the scratch buffers pointers to the global scratch buffer.
- TfLiteTensor* scratch_buffer = GetTemporary(context, node, /*index=*/0);
-
- TfLiteTensor* activation_state =
- &context->tensors[node->inputs->data[kInputActivationStateTensor]];
- TfLiteTensor* cell_state =
- &context->tensors[node->inputs->data[kInputCellStateTensor]];
-
- TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
-
- switch (input_to_output_weights->type) {
- case kTfLiteFloat32: {
- return EvalFloat(input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights,
- cell_to_output_weights, input_layer_norm_weights,
- forget_layer_norm_weights, cell_layer_norm_weights,
- output_layer_norm_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias,
- projection_weights, projection_bias, op_data->cell_clip,
- op_data->proj_clip, op_data->activation, scratch_buffer,
- activation_state, cell_state, output);
- }
- case kTfLiteUInt8: {
- TfLiteTensor* input_quantized = GetTemporary(context, node, /*index=*/1);
- TfLiteTensor* activation_state_quantized =
- GetTemporary(context, node, /*index=*/2);
- TfLiteTensor* cell_state_quantized =
- GetTemporary(context, node, /*index=*/3);
- TfLiteTensor* scaling_factors = GetTemporary(context, node, /*index=*/4);
- TfLiteTensor* prod_scaling_factors =
- GetTemporary(context, node, /*index=*/5);
- TfLiteTensor* recovered_weights =
- GetTemporary(context, node, /*index=*/6);
- return EvalHybrid(
- input, input_to_input_weights, input_to_forget_weights,
- input_to_cell_weights, input_to_output_weights,
- recurrent_to_input_weights, recurrent_to_forget_weights,
- recurrent_to_cell_weights, recurrent_to_output_weights,
- cell_to_input_weights, cell_to_forget_weights, cell_to_output_weights,
- input_layer_norm_weights, forget_layer_norm_weights,
- cell_layer_norm_weights, output_layer_norm_weights, input_gate_bias,
- forget_gate_bias, cell_bias, output_gate_bias, projection_weights,
- projection_bias, op_data->cell_clip, op_data->proj_clip,
- op_data->activation, scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_weights, input_quantized,
- activation_state_quantized, cell_state_quantized, activation_state,
- cell_state, output);
- }
- default:
- context->ReportError(context, "Type %d is not currently supported.",
- input_to_output_weights->type);
- return kTfLiteError;
- }
- return kTfLiteOk;
-}
-
-void Free(TfLiteContext* context, void* buffer) {
- delete reinterpret_cast<OpData*>(buffer);
-}
-
-} // namespace layer_norm_lstm
-
-TfLiteRegistration* Register_LAYER_NORM_LSTM() {
- static TfLiteRegistration r = {layer_norm_lstm::Init, layer_norm_lstm::Free,
- layer_norm_lstm::Prepare,
- layer_norm_lstm::Eval};
- return &r;
-}
-
-} // namespace custom
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/kernels/layer_norm_lstm_test.cc b/tensorflow/lite/kernels/layer_norm_lstm_test.cc
deleted file mode 100644
index 5aed818..0000000
--- a/tensorflow/lite/kernels/layer_norm_lstm_test.cc
+++ /dev/null
@@ -1,885 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-// Unit test for TFLite Layer Norm LSTM op.
-
-#include <memory>
-#include <vector>
-
-#include <gmock/gmock.h>
-#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
-#include "tensorflow/lite/interpreter.h"
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-#include "tensorflow/lite/model.h"
-
-namespace tflite {
-namespace ops {
-namespace custom {
-
-TfLiteRegistration* Register_LAYER_NORM_LSTM();
-
-namespace {
-
-using ::testing::ElementsAreArray;
-
-class LayerNormLSTMOpModel : public SingleOpModel {
- public:
- LayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
- bool use_cifg, bool use_peephole,
- bool use_projection_weights, bool use_projection_bias,
- float cell_clip, float proj_clip,
- const std::vector<std::vector<int>>& input_shapes,
- const TensorType& weight_type = TensorType_FLOAT32)
- : n_batch_(n_batch),
- n_input_(n_input),
- n_cell_(n_cell),
- n_output_(n_output) {
- input_ = AddInput(TensorType_FLOAT32);
-
- if (use_cifg) {
- input_to_input_weights_ = AddNullInput();
- } else {
- input_to_input_weights_ = AddInput(weight_type);
- }
-
- input_to_forget_weights_ = AddInput(weight_type);
- input_to_cell_weights_ = AddInput(weight_type);
- input_to_output_weights_ = AddInput(weight_type);
-
- if (use_cifg) {
- recurrent_to_input_weights_ = AddNullInput();
- } else {
- recurrent_to_input_weights_ = AddInput(weight_type);
- }
-
- recurrent_to_forget_weights_ = AddInput(weight_type);
- recurrent_to_cell_weights_ = AddInput(weight_type);
- recurrent_to_output_weights_ = AddInput(weight_type);
-
- if (use_peephole) {
- if (use_cifg) {
- cell_to_input_weights_ = AddNullInput();
- } else {
- cell_to_input_weights_ = AddInput(weight_type);
- }
- cell_to_forget_weights_ = AddInput(weight_type);
- cell_to_output_weights_ = AddInput(weight_type);
- } else {
- cell_to_input_weights_ = AddNullInput();
- cell_to_forget_weights_ = AddNullInput();
- cell_to_output_weights_ = AddNullInput();
- }
-
- if (use_cifg) {
- input_layer_norm_weights_ = AddNullInput();
- } else {
- input_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
- }
- forget_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
- cell_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
- output_layer_norm_weights_ = AddInput(TensorType_FLOAT32);
-
- if (use_cifg) {
- input_gate_bias_ = AddNullInput();
- } else {
- input_gate_bias_ = AddInput(TensorType_FLOAT32);
- }
- forget_gate_bias_ = AddInput(TensorType_FLOAT32);
- cell_bias_ = AddInput(TensorType_FLOAT32);
- output_gate_bias_ = AddInput(TensorType_FLOAT32);
-
- if (use_projection_weights) {
- projection_weights_ = AddInput(weight_type);
- if (use_projection_bias) {
- projection_bias_ = AddInput(TensorType_FLOAT32);
- } else {
- projection_bias_ = AddNullInput();
- }
- } else {
- projection_weights_ = AddNullInput();
- projection_bias_ = AddNullInput();
- }
-
- // Adding the 2 state tensors.
- output_state_ =
- AddInput(TensorData{TensorType_FLOAT32, {n_output_ * n_batch_}}, true);
- cell_state_ =
- AddInput(TensorData{TensorType_FLOAT32, {n_cell_ * n_batch_}}, true);
-
- output_ = AddOutput(TensorType_FLOAT32);
-
- // Set up and pass in custom options using flexbuffer.
- flexbuffers::Builder fbb;
- fbb.Map([&]() {
- fbb.Int("cell_clip", cell_clip);
- fbb.Int("proj_clip", proj_clip);
- fbb.String("fused_activation_function", "TANH");
- });
- fbb.Finish();
- SetCustomOp("LAYER_NORM_LSTM", fbb.GetBuffer(), Register_LAYER_NORM_LSTM);
- BuildInterpreter(input_shapes);
- }
-
- void SetInputToInputWeights(const std::vector<float>& f) {
- PopulateTensor(input_to_input_weights_, f);
- }
-
- void SetInputToForgetWeights(const std::vector<float>& f) {
- PopulateTensor(input_to_forget_weights_, f);
- }
-
- void SetInputToCellWeights(const std::vector<float>& f) {
- PopulateTensor(input_to_cell_weights_, f);
- }
-
- void SetInputToOutputWeights(const std::vector<float>& f) {
- PopulateTensor(input_to_output_weights_, f);
- }
-
- void SetRecurrentToInputWeights(const std::vector<float>& f) {
- PopulateTensor(recurrent_to_input_weights_, f);
- }
-
- void SetRecurrentToForgetWeights(const std::vector<float>& f) {
- PopulateTensor(recurrent_to_forget_weights_, f);
- }
-
- void SetRecurrentToCellWeights(const std::vector<float>& f) {
- PopulateTensor(recurrent_to_cell_weights_, f);
- }
-
- void SetRecurrentToOutputWeights(const std::vector<float>& f) {
- PopulateTensor(recurrent_to_output_weights_, f);
- }
-
- void SetCellToInputWeights(const std::vector<float>& f) {
- PopulateTensor(cell_to_input_weights_, f);
- }
-
- void SetCellToForgetWeights(const std::vector<float>& f) {
- PopulateTensor(cell_to_forget_weights_, f);
- }
-
- void SetCellToOutputWeights(const std::vector<float>& f) {
- PopulateTensor(cell_to_output_weights_, f);
- }
-
- void SetInputLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(input_layer_norm_weights_, f);
- }
-
- void SetForgetLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(forget_layer_norm_weights_, f);
- }
-
- void SetCellLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(cell_layer_norm_weights_, f);
- }
-
- void SetOutputLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(output_layer_norm_weights_, f);
- }
-
- void SetInputGateBias(const std::vector<float>& f) {
- PopulateTensor(input_gate_bias_, f);
- }
-
- void SetForgetGateBias(const std::vector<float>& f) {
- PopulateTensor(forget_gate_bias_, f);
- }
-
- void SetCellBias(const std::vector<float>& f) {
- PopulateTensor(cell_bias_, f);
- }
-
- void SetOutputGateBias(const std::vector<float>& f) {
- PopulateTensor(output_gate_bias_, f);
- }
-
- void SetProjectionWeights(const std::vector<float>& f) {
- PopulateTensor(projection_weights_, f);
- }
-
- void SetProjectionBias(const std::vector<float>& f) {
- PopulateTensor(projection_bias_, f);
- }
-
- void SetInput(int offset, const float* begin, const float* end) {
- PopulateTensor(input_, offset, const_cast<float*>(begin),
- const_cast<float*>(end));
- }
-
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-
- int num_inputs() { return n_input_; }
- int num_outputs() { return n_output_; }
- int num_cells() { return n_cell_; }
- int num_batches() { return n_batch_; }
-
- protected:
- int input_;
- int input_to_input_weights_;
- int input_to_forget_weights_;
- int input_to_cell_weights_;
- int input_to_output_weights_;
-
- int recurrent_to_input_weights_;
- int recurrent_to_forget_weights_;
- int recurrent_to_cell_weights_;
- int recurrent_to_output_weights_;
-
- int cell_to_input_weights_;
- int cell_to_forget_weights_;
- int cell_to_output_weights_;
-
- int input_layer_norm_weights_;
- int forget_layer_norm_weights_;
- int cell_layer_norm_weights_;
- int output_layer_norm_weights_;
-
- int input_gate_bias_;
- int forget_gate_bias_;
- int cell_bias_;
- int output_gate_bias_;
-
- int projection_weights_;
- int projection_bias_;
-
- int output_state_;
- int cell_state_;
-
- int output_;
-
- int n_batch_;
- int n_input_;
- int n_cell_;
- int n_output_;
-};
-
-class HybridLayerNormLSTMOpModel : public LayerNormLSTMOpModel {
- public:
- HybridLayerNormLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
- bool use_cifg, bool use_peephole,
- bool use_projection_weights,
- bool use_projection_bias, float cell_clip,
- float proj_clip,
- const std::vector<std::vector<int>>& input_shapes)
- : LayerNormLSTMOpModel(n_batch, n_input, n_cell, n_output, use_cifg,
- use_peephole, use_projection_weights,
- use_projection_bias, cell_clip, proj_clip,
- input_shapes, TensorType_UINT8) {}
-
- void SetInputToInputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(input_to_input_weights_, f);
- }
-
- void SetInputToForgetWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(input_to_forget_weights_, f);
- }
-
- void SetInputToCellWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(input_to_cell_weights_, f);
- }
-
- void SetInputToOutputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(input_to_output_weights_, f);
- }
-
- void SetRecurrentToInputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(recurrent_to_input_weights_, f);
- }
-
- void SetRecurrentToForgetWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(recurrent_to_forget_weights_, f);
- }
-
- void SetRecurrentToCellWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(recurrent_to_cell_weights_, f);
- }
-
- void SetRecurrentToOutputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(recurrent_to_output_weights_, f);
- }
-
- void SetCellToInputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(cell_to_input_weights_, f);
- }
-
- void SetCellToForgetWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(cell_to_forget_weights_, f);
- }
-
- void SetCellToOutputWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(cell_to_output_weights_, f);
- }
-
- void SetInputLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(input_layer_norm_weights_, f);
- }
-
- void SetForgetLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(forget_layer_norm_weights_, f);
- }
-
- void SetCellLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(cell_layer_norm_weights_, f);
- }
-
- void SetOutputLayerNormWeights(const std::vector<float>& f) {
- PopulateTensor(output_layer_norm_weights_, f);
- }
-
- void SetProjectionWeights(const std::vector<float>& f) {
- SymmetricQuantizeAndPopulate(projection_weights_, f);
- }
-};
-
-class BaseLayerNormLstmTest : public ::testing::Test {
- protected:
- // Weights of the Layer Norm LSTM model. Some are optional.
- std::vector<float> input_to_input_weights_;
- std::vector<float> input_to_cell_weights_;
- std::vector<float> input_to_forget_weights_;
- std::vector<float> input_to_output_weights_;
- std::vector<float> input_gate_bias_;
- std::vector<float> cell_gate_bias_;
- std::vector<float> forget_gate_bias_;
- std::vector<float> output_gate_bias_;
- std::vector<float> recurrent_to_input_weights_;
- std::vector<float> recurrent_to_cell_weights_;
- std::vector<float> recurrent_to_forget_weights_;
- std::vector<float> recurrent_to_output_weights_;
- std::vector<float> cell_to_input_weights_;
- std::vector<float> cell_to_forget_weights_;
- std::vector<float> cell_to_output_weights_;
- std::vector<float> input_layer_norm_weights_;
- std::vector<float> forget_layer_norm_weights_;
- std::vector<float> cell_layer_norm_weights_;
- std::vector<float> output_layer_norm_weights_;
- std::vector<float> projection_weights_;
-
- // Layer Norm LSTM input is stored as num_batch x num_inputs vector.
- std::vector<std::vector<float>> layer_norm_lstm_input_;
-
- // Compares output up to tolerance to the result of the layer_norm_lstm given
- // the input.
- void VerifyGoldens(const std::vector<std::vector<float>>& input,
- const std::vector<std::vector<float>>& output,
- LayerNormLSTMOpModel* layer_norm_lstm,
- float tolerance = 1e-5) {
- const int num_batches = input.size();
- EXPECT_GT(num_batches, 0);
- const int num_inputs = layer_norm_lstm->num_inputs();
- EXPECT_GT(num_inputs, 0);
- const int input_sequence_size = input[0].size() / num_inputs;
- EXPECT_GT(input_sequence_size, 0);
- for (int i = 0; i < input_sequence_size; ++i) {
- for (int b = 0; b < num_batches; ++b) {
- const float* batch_start = input[b].data() + i * num_inputs;
- const float* batch_end = batch_start + num_inputs;
-
- layer_norm_lstm->SetInput(b * layer_norm_lstm->num_inputs(),
- batch_start, batch_end);
- }
-
- layer_norm_lstm->Invoke();
-
- const int num_outputs = layer_norm_lstm->num_outputs();
- std::vector<float> expected;
- for (int b = 0; b < num_batches; ++b) {
- const float* golden_start_batch = output[b].data() + i * num_outputs;
- const float* golden_end_batch = golden_start_batch + num_outputs;
- expected.insert(expected.end(), golden_start_batch, golden_end_batch);
- }
- EXPECT_THAT(layer_norm_lstm->GetOutput(),
- ElementsAreArray(ArrayFloatNear(expected, tolerance)));
- }
- }
-};
-
-class NoCifgPeepholeProjectionNoClippingLayerNormLstmTest
- : public BaseLayerNormLstmTest {
- void SetUp() override {
- input_to_input_weights_ = {0.5, 0.6, 0.7, -0.8, -0.9, 0.1, 0.2,
- 0.3, -0.4, 0.5, -0.8, 0.7, -0.6, 0.5,
- -0.4, -0.5, -0.4, -0.3, -0.2, -0.1};
-
- input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
- -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
- -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
-
- input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
- -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
- -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
-
- input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
- -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
- -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
-
- input_gate_bias_ = {0.03, 0.15, 0.22, 0.38};
-
- forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
-
- cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
-
- output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
-
- recurrent_to_input_weights_ = {-0.2, -0.3, 0.4, 0.1, -0.5, 0.9,
- -0.2, -0.3, -0.7, 0.05, -0.2, -0.6};
-
- recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
- -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
-
- recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
- 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
-
- recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
- -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
-
- cell_to_input_weights_ = {0.05, 0.1, 0.25, 0.15};
-
- cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
-
- cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
-
- input_layer_norm_weights_ = {0.1, 0.2, 0.3, 0.5};
- forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
- cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
- output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
-
- projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
- 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
-
- layer_norm_lstm_input_ = {
- {// Batch0: 3 (input_sequence_size) * 5 (n_input)
- 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
- 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
- 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
-
- {// Batch1: 3 (input_sequence_size) * 5 (n_input)
- 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
- 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
- 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
- };
- }
-};
-
-TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
- LayerNormLstmBlackBoxTest) {
- const int n_batch = 2;
- const int n_input = 5;
- const int n_cell = 4;
- const int n_output = 3;
- const float ceil_clip = 0.0;
- const float proj_clip = 0.0;
-
- LayerNormLSTMOpModel layer_norm_lstm(
- n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/false, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, ceil_clip, proj_clip,
- {
- {n_batch, n_input}, // input tensor
-
- {n_cell, n_input}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {n_cell}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {n_cell}, // input_layer_norm_weight tensor
- {n_cell}, // forget_layer_norm_weight tensor
- {n_cell}, // cell_layer_norm_weight tensor
- {n_cell}, // output_layer_norm_weight tensor
-
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
- layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
- layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
- layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
- layer_norm_lstm.SetInputGateBias(input_gate_bias_);
- layer_norm_lstm.SetCellBias(cell_gate_bias_);
- layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
- layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
- layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
- layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
- layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
- layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
- layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
- layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
- layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
- layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
- layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
- layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
- layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
-
- layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
- // Verify the final output.
- const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
- {
- // Batch0: 3 (input_sequence_size) * 3 (n_output)
- 0.0244077, 0.128027, -0.00170918, // seq 0
- 0.0137642, 0.140751, 0.0395835, // seq 1
- -0.00459231, 0.155278, 0.0837377, // seq 2
- },
- {
- // Batch1: 3 (input_sequence_size) * 3 (n_output)
- -0.00692428, 0.0848741, 0.063445, // seq 0
- -0.00403912, 0.139963, 0.072681, // seq 1
- 0.00752706, 0.161903, 0.0561371, // seq 2
- }};
-
- VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
- &layer_norm_lstm);
-}
-
-TEST_F(NoCifgPeepholeProjectionNoClippingLayerNormLstmTest,
- HybridLayerNormLstmBlackBoxTest) {
- const int n_batch = 2;
- const int n_input = 5;
- const int n_cell = 4;
- const int n_output = 3;
- const float ceil_clip = 0.0;
- const float proj_clip = 0.0;
-
- HybridLayerNormLSTMOpModel layer_norm_lstm(
- n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/false, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, ceil_clip, proj_clip,
- {
- {n_batch, n_input}, // input tensor
-
- {n_cell, n_input}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {n_cell, n_output}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {n_cell}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {n_cell}, // input_layer_norm_weight tensor
- {n_cell}, // forget_layer_norm_weight tensor
- {n_cell}, // cell_layer_norm_weight tensor
- {n_cell}, // output_layer_norm_weight tensor
-
- {n_cell}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- layer_norm_lstm.SetInputToInputWeights(input_to_input_weights_);
- layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
- layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
- layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
- layer_norm_lstm.SetInputGateBias(input_gate_bias_);
- layer_norm_lstm.SetCellBias(cell_gate_bias_);
- layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
- layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
- layer_norm_lstm.SetRecurrentToInputWeights(recurrent_to_input_weights_);
- layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
- layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
- layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
- layer_norm_lstm.SetCellToInputWeights(cell_to_input_weights_);
- layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
- layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
- layer_norm_lstm.SetInputLayerNormWeights(input_layer_norm_weights_);
- layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
- layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
- layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
-
- layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
- const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
- {
- // Batch0: 3 (input_sequence_size) * 3 (n_output)
- 0.0244576, 0.127847, -0.00181765, // seq 0
- 0.0137518, 0.140892, 0.0402234, // seq 1
- -0.0048839, 0.155096, 0.0840309, // seq 2
- },
- {
- // Batch1: 3 (input_sequence_size) * 3 (n_output)
- -0.00728636, 0.0843957, 0.0634786, // seq 0
- -0.00448382, 0.139278, 0.0737372, // seq 1
- 0.00734616, 0.161793, 0.0560238, // seq 2
- }};
-
- VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
- &layer_norm_lstm);
-}
-
-class CifgPeepholeProjectionNoClippingLayerNormLstmTest
- : public BaseLayerNormLstmTest {
- void SetUp() override {
- input_to_forget_weights_ = {-0.6, -0.1, 0.3, 0.2, 0.9, -0.5, -0.2,
- -0.4, 0.3, -0.8, -0.4, 0.3, -0.5, -0.4,
- -0.6, 0.3, -0.4, -0.6, -0.5, -0.5};
- input_to_cell_weights_ = {-0.4, -0.3, -0.2, -0.1, -0.5, 0.5, -0.2,
- -0.3, -0.2, -0.6, 0.6, -0.1, -0.4, -0.3,
- -0.7, 0.7, -0.9, -0.5, 0.8, 0.6};
- input_to_output_weights_ = {-0.8, -0.4, -0.2, -0.9, -0.1, -0.7, 0.3,
- -0.3, -0.8, -0.2, 0.6, -0.2, 0.4, -0.7,
- -0.3, -0.5, 0.1, 0.5, -0.6, -0.4};
-
- forget_gate_bias_ = {0.1, -0.3, -0.2, 0.1};
- cell_gate_bias_ = {-0.05, 0.72, 0.25, 0.08};
- output_gate_bias_ = {0.05, -0.01, 0.2, 0.1};
-
- recurrent_to_cell_weights_ = {-0.3, 0.2, 0.1, -0.3, 0.8, -0.08,
- -0.2, 0.3, 0.8, -0.6, -0.1, 0.2};
- recurrent_to_forget_weights_ = {-0.5, -0.3, -0.5, -0.2, 0.6, 0.4,
- 0.9, 0.3, -0.1, 0.2, 0.5, 0.2};
- recurrent_to_output_weights_ = {0.3, -0.1, 0.1, -0.2, -0.5, -0.7,
- -0.2, -0.6, -0.1, -0.4, -0.7, -0.2};
-
- cell_to_forget_weights_ = {-0.02, -0.15, -0.25, -0.03};
- cell_to_output_weights_ = {0.1, -0.1, -0.5, 0.05};
-
- forget_layer_norm_weights_ = {0.2, 0.2, 0.4, 0.3};
- cell_layer_norm_weights_ = {0.7, 0.2, 0.3, 0.8};
- output_layer_norm_weights_ = {0.6, 0.2, 0.2, 0.5};
- projection_weights_ = {-0.1, 0.2, 0.01, -0.2, 0.1, 0.5,
- 0.3, 0.08, 0.07, 0.2, -0.4, 0.2};
-
- layer_norm_lstm_input_ = {
- {// Batch0: 3 (input_sequence_size) * 5 (n_input)
- 0.7, 0.8, 0.1, 0.2, 0.3, // seq 0
- 0.8, 0.1, 0.2, 0.4, 0.5, // seq 1
- 0.2, 0.7, 0.7, 0.1, 0.7}, // seq 2
-
- {// Batch1: 3 (input_sequence_size) * 5 (n_input)
- 0.3, 0.2, 0.9, 0.8, 0.1, // seq 0
- 0.1, 0.5, 0.2, 0.4, 0.2, // seq 1
- 0.6, 0.9, 0.2, 0.5, 0.7}, // seq 2
- };
- }
-};
-
-TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
- LayerNormLstmBlackBoxTest) {
- const int n_batch = 2;
- const int n_input = 5;
- const int n_cell = 4;
- const int n_output = 3;
- const float ceil_clip = 0.0;
- const float proj_clip = 0.0;
-
- LayerNormLSTMOpModel layer_norm_lstm(
- n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/true, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, ceil_clip, proj_clip,
- {
- {n_batch, n_input}, // input tensor
-
- {0, 0}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {0, 0}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {0}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {0}, // input_layer_norm_weight tensor
- {n_cell}, // forget_layer_norm_weight tensor
- {n_cell}, // cell_layer_norm_weight tensor
- {n_cell}, // output_layer_norm_weight tensor
-
- {0}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
- layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
- layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
- layer_norm_lstm.SetCellBias(cell_gate_bias_);
- layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
- layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
- layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
- layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
- layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
- layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
- layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
- layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
- layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
- layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
-
- layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
- // Verify the final output.
- const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
- {
- // Batch0: 3 (input_sequence_size) * 3 (n_output)
- 0.02129706, 0.140816242, 0.0112733059, // seq 0
- 0.0132302344, 0.152308047, 0.0346313119, // seq 1
- -0.0123688057, 0.165790111, 0.0893077999, // seq 2
- },
- {
- // Batch1: 3 (input_sequence_size) * 3 (n_output)
- -0.0226350538, 0.0916948169, 0.0769175813, // seq 0
- -0.0269966982, 0.149707705, 0.094149217, // seq 1
- -0.0103429332, 0.173016444, 0.0720508844, // seq 2
- }};
-
- VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
- &layer_norm_lstm);
-}
-
-TEST_F(CifgPeepholeProjectionNoClippingLayerNormLstmTest,
- HybridLayerNormLstmBlackBoxTest) {
- const int n_batch = 2;
- const int n_input = 5;
- const int n_cell = 4;
- const int n_output = 3;
- const float ceil_clip = 0.0;
- const float proj_clip = 0.0;
-
- HybridLayerNormLSTMOpModel layer_norm_lstm(
- n_batch, n_input, n_cell, n_output,
- /*use_cifg=*/true, /*use_peephole=*/true,
- /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, ceil_clip, proj_clip,
- {
- {n_batch, n_input}, // input tensor
-
- {0, 0}, // input_to_input_weight tensor
- {n_cell, n_input}, // input_to_forget_weight tensor
- {n_cell, n_input}, // input_to_cell_weight tensor
- {n_cell, n_input}, // input_to_output_weight tensor
-
- {0, 0}, // recurrent_to_input_weight tensor
- {n_cell, n_output}, // recurrent_to_forget_weight tensor
- {n_cell, n_output}, // recurrent_to_cell_weight tensor
- {n_cell, n_output}, // recurrent_to_output_weight tensor
-
- {0}, // cell_to_input_weight tensor
- {n_cell}, // cell_to_forget_weight tensor
- {n_cell}, // cell_to_output_weight tensor
-
- {0}, // input_layer_norm_weight tensor
- {n_cell}, // forget_layer_norm_weight tensor
- {n_cell}, // cell_layer_norm_weight tensor
- {n_cell}, // output_layer_norm_weight tensor
-
- {0}, // input_gate_bias tensor
- {n_cell}, // forget_gate_bias tensor
- {n_cell}, // cell_bias tensor
- {n_cell}, // output_gate_bias tensor
-
- {n_output, n_cell}, // projection_weight tensor
- {0}, // projection_bias tensor
- });
-
- layer_norm_lstm.SetInputToCellWeights(input_to_cell_weights_);
- layer_norm_lstm.SetInputToForgetWeights(input_to_forget_weights_);
- layer_norm_lstm.SetInputToOutputWeights(input_to_output_weights_);
-
- layer_norm_lstm.SetCellBias(cell_gate_bias_);
- layer_norm_lstm.SetForgetGateBias(forget_gate_bias_);
- layer_norm_lstm.SetOutputGateBias(output_gate_bias_);
-
- layer_norm_lstm.SetRecurrentToCellWeights(recurrent_to_cell_weights_);
- layer_norm_lstm.SetRecurrentToForgetWeights(recurrent_to_forget_weights_);
- layer_norm_lstm.SetRecurrentToOutputWeights(recurrent_to_output_weights_);
-
- layer_norm_lstm.SetCellToForgetWeights(cell_to_forget_weights_);
- layer_norm_lstm.SetCellToOutputWeights(cell_to_output_weights_);
-
- layer_norm_lstm.SetForgetLayerNormWeights(forget_layer_norm_weights_);
- layer_norm_lstm.SetCellLayerNormWeights(cell_layer_norm_weights_);
- layer_norm_lstm.SetOutputLayerNormWeights(output_layer_norm_weights_);
-
- layer_norm_lstm.SetProjectionWeights(projection_weights_);
-
- // Verify the final output.
- const std::vector<std::vector<float>> layer_norm_lstm_golden_output = {
- {
- // Batch0: 3 (input_sequence_size) * 3 (n_output)
- 0.0212250091, 0.140474007, 0.0115012666, // seq 0
- 0.0130806509, 0.152660668, 0.0347516984, // seq 1
- -0.0124010444, 0.166042402, 0.0898982584, // seq 2
- },
- {
- // Batch1: 3 (input_sequence_size) * 3 (n_output)
- -0.0228835996, 0.0917588323, 0.0778886303, // seq 0
- -0.0275101066, 0.148769245, 0.0938384682, // seq 1
- -0.0103605557, 0.172605693, 0.0728750974, // seq 2
- }};
-
- VerifyGoldens(layer_norm_lstm_input_, layer_norm_lstm_golden_output,
- &layer_norm_lstm);
-}
-
-} // namespace
-} // namespace custom
-} // namespace ops
-} // namespace tflite
-
-int main(int argc, char** argv) {
- ::tflite::LogToStderr();
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/kernels/mirror_pad.cc b/tensorflow/lite/kernels/mirror_pad.cc
index 70f7e33..65a98ef 100644
--- a/tensorflow/lite/kernels/mirror_pad.cc
+++ b/tensorflow/lite/kernels/mirror_pad.cc
@@ -37,6 +37,10 @@
// Note: This is not owned by default. It will point to the value
// in the input tensor.
const void* value = nullptr;
+ // The start index of the values of this tensor in the output buffer.
+ int start = -1;
+ // The end index of the values of this tensor in the output buffer.
+ int end = -1;
// If this tensor is not one value, then this vector will have
// all the tensors that belongs to this tensor.
// Pointers are not owned.
@@ -66,30 +70,37 @@
struct OpData {
// Holds intermediate data structure of the padded tensor.
std::vector<PaddedTensor> pad_tensor_buffer;
- // Total number of intermediate elements in the pad_tensor_buffer.
- int num_elements;
};
// Util method to initialize the memory of the padded tensor.
-// Returns the index of the current item processed in 'padded_tensor_buffer'
-int InitializeTensorMemory(const TfLiteIntArray* const dims, int dim_index,
- int dims_size,
- std::vector<PaddedTensor>* padded_tensor_buffer,
- int element_index, int num_elements) {
- if (dim_index >= dims_size) {
- return element_index;
+void InitializeTensorMemory(const TfLiteIntArray* const dims, int dims_size,
+ std::vector<PaddedTensor>* padded_tensor_buffer) {
+ int dimension_index = 0;
+ int element_index = 0;
+ // We hold 2 vectors with values for nodes in current level, and
+ // nodes in the next level, and swap while moving on dimensions of the tensor.
+ std::vector<PaddedTensor*> current_nodes, next_level;
+ current_nodes.push_back(&(*padded_tensor_buffer)[element_index]);
+ current_nodes[0]->start = current_nodes[0]->end = -1;
+ element_index++;
+ int next_level_size = 1;
+ while (!current_nodes.empty() && dimension_index < dims_size) {
+ next_level_size *= dims->data[dimension_index];
+ next_level.resize(next_level_size);
+ // Index of elements in next level.
+ int index = 0;
+ for (auto* padded_tensor : current_nodes) {
+ padded_tensor->values.resize(dims->data[dimension_index]);
+ for (int i = 0; i < dims->data[dimension_index]; ++i) {
+ padded_tensor->values[i] = &(*padded_tensor_buffer)[element_index];
+ padded_tensor->values[i]->start = padded_tensor->values[i]->end = -1;
+ next_level[index++] = padded_tensor->values[i];
+ element_index++;
+ }
+ }
+ std::swap(current_nodes, next_level);
+ dimension_index++;
}
- PaddedTensor* padded_tensor = &(*padded_tensor_buffer)[element_index];
- padded_tensor->values.clear();
- padded_tensor->values.reserve(dims->data[dim_index]);
- for (int i = 0; i < dims->data[dim_index]; ++i) {
- ++element_index;
- padded_tensor->values.emplace_back(&(*padded_tensor_buffer)[element_index]);
- element_index = InitializeTensorMemory(dims, dim_index + 1, dims_size,
- padded_tensor_buffer, element_index,
- num_elements);
- }
- return element_index;
}
// Returns pointer to the value at the specified index in 'data'.
@@ -117,20 +128,6 @@
return nullptr;
}
-// Util method that increment index in the N-d array.
-void IncrementTensorIndex(const TfLiteIntArray* dims,
- std::vector<int>* tensor_index_ptr) {
- int dimension_index = dims->size - 1;
- auto& tensor_index = *tensor_index_ptr;
- tensor_index[dimension_index]++;
- while (dimension_index >= 0 &&
- tensor_index[dimension_index] == dims->data[dimension_index]) {
- tensor_index[dimension_index] = 0;
- dimension_index--;
- if (dimension_index >= 0) tensor_index[dimension_index]++;
- }
-}
-
// Fills the 'padded_tensor' with data from 'input_tensor'.
TfLiteStatus InitFromInputTensor(const TfLiteTensor* input_tensor,
PaddedTensor* padded_tensor) {
@@ -145,13 +142,13 @@
std::vector<int> tensor_index(dims->size, 0);
int flat_index = 0;
const int num_elements = NumElements(input_tensor);
+ auto* tensor = padded_tensor->GetMutable(tensor_index);
while (flat_index < num_elements) {
- auto* tensor = padded_tensor->GetMutable(tensor_index);
if (tensor == nullptr) {
return kTfLiteError;
}
tensor->value = GetValuePointerAtIndex(data, flat_index, data_type);
- IncrementTensorIndex(dims, &tensor_index);
+ ++tensor;
++flat_index;
}
@@ -245,25 +242,40 @@
// Fills 'output_data' with data from 'padded_tensor'.
// The function does this recursively by setting left padding first then
// original data, followed by the right padding.
+// The functions returns the index in 'output_data' to be filled with data.
template <typename T>
-int FillOutput(const PaddedTensor* padded_tensor, T* output_data,
+int FillOutput(PaddedTensor* padded_tensor, T* output_data,
int index_in_output) {
if (padded_tensor == nullptr || output_data == nullptr) {
return -1;
}
+ // Check if this tensor value was computed and written in the output
+ // already. If yes, just copy the values.
+ if (padded_tensor->start != -1) {
+ const int size = padded_tensor->end - padded_tensor->start + 1;
+ memcpy(output_data + index_in_output, output_data + padded_tensor->start,
+ size * sizeof(T));
+ return index_in_output + size;
+ }
+ // Record the start index in the output.
+ padded_tensor->start = index_in_output;
+ // Check for single value.
if (padded_tensor->value != nullptr) {
output_data[index_in_output] = *static_cast<const T*>(padded_tensor->value);
+ padded_tensor->end = index_in_output;
return index_in_output + 1;
}
- for (const auto* tensor : padded_tensor->left_pad_ptrs) {
+ for (auto* tensor : padded_tensor->left_pad_ptrs) {
index_in_output = FillOutput(tensor, output_data, index_in_output);
}
- for (const auto& tensor : padded_tensor->values) {
+ for (auto& tensor : padded_tensor->values) {
index_in_output = FillOutput(tensor, output_data, index_in_output);
}
- for (const auto* tensor : padded_tensor->right_pad_ptrs) {
+ for (auto* tensor : padded_tensor->right_pad_ptrs) {
index_in_output = FillOutput(tensor, output_data, index_in_output);
}
+ // Record the end index in the output.
+ padded_tensor->end = index_in_output - 1;
return index_in_output;
}
@@ -308,8 +320,8 @@
PaddedTensor& padded_tensor = op_data->pad_tensor_buffer[0];
// Initialize memory.
- InitializeTensorMemory(input_tensor->dims, 0, input_dims,
- &op_data->pad_tensor_buffer, 0, op_data->num_elements);
+ InitializeTensorMemory(input_tensor->dims, input_dims,
+ &op_data->pad_tensor_buffer);
// Set the values from the input_tensor.
TF_LITE_ENSURE_STATUS(InitFromInputTensor(input_tensor, &padded_tensor));
const int offset =
@@ -380,7 +392,6 @@
num_elements += extra_nodes;
}
op_data->pad_tensor_buffer.resize(num_elements);
- op_data->num_elements = num_elements;
if (!IsConstantTensor(padding_matrix)) {
SetTensorToDynamic(output_tensor);
diff --git a/tensorflow/lite/kernels/mirror_pad_test.cc b/tensorflow/lite/kernels/mirror_pad_test.cc
index fd09e6e..91e48fa 100644
--- a/tensorflow/lite/kernels/mirror_pad_test.cc
+++ b/tensorflow/lite/kernels/mirror_pad_test.cc
@@ -185,5 +185,18 @@
EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2}));
}
+TEST(MirrorPadTest, Pad_1D_Symmetric_Multiple_Invoke) {
+ BaseMirrorPadOpModel<int> model(
+ {TensorType_INT32, {3}}, {TensorType_INT32, {1, 2}},
+ {TensorType_INT32, {}}, tflite::MirrorPadMode_SYMMETRIC);
+ model.PopulateTensor<int>(model.input_tensor_id(), {1, 2, 3});
+ model.PopulateTensor<int>(model.padding_matrix_tensor_id(), {0, 2});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({1, 2, 3, 3, 2}));
+ model.PopulateTensor<int>(model.input_tensor_id(), {4, 5, 6});
+ model.Invoke();
+ EXPECT_THAT(model.GetOutput(), ElementsAreArray({4, 5, 6, 6, 5}));
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc
index fcfe0b2..948bf77 100644
--- a/tensorflow/lite/kernels/register.cc
+++ b/tensorflow/lite/kernels/register.cc
@@ -22,10 +22,9 @@
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
-TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
-TfLiteRegistration* Register_RELU_1();
+TfLiteRegistration* Register_IF();
} // namespace custom
@@ -132,6 +131,7 @@
TfLiteRegistration* Register_MIRROR_PAD();
TfLiteRegistration* Register_UNIQUE();
TfLiteRegistration* Register_REVERSE_V2();
+TfLiteRegistration* Register_ADD_N();
TfLiteStatus UnsupportedTensorFlowOp(TfLiteContext* context, TfLiteNode* node) {
context->ReportError(
@@ -167,15 +167,17 @@
AddBuiltin(BuiltinOperator_RELU6, Register_RELU6());
AddBuiltin(BuiltinOperator_TANH, Register_TANH());
AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC());
- AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D());
+ AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, Register_AVERAGE_POOL_2D(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_MAX_POOL_2D, Register_MAX_POOL_2D());
AddBuiltin(BuiltinOperator_L2_POOL_2D, Register_L2_POOL_2D());
AddBuiltin(BuiltinOperator_CONV_2D, Register_CONV_2D(),
/* min_version */ 1,
- /* max_version */ 2);
+ /* max_version */ 3);
AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D(),
/* min_version */ 1,
- /* max_version */ 2);
+ /* max_version */ 3);
AddBuiltin(BuiltinOperator_SVDF, Register_SVDF(),
/* min_version */ 1,
/* max_version */ 2);
@@ -200,7 +202,9 @@
/* max_version */ 3);
AddBuiltin(BuiltinOperator_LSH_PROJECTION, Register_LSH_PROJECTION());
AddBuiltin(BuiltinOperator_HASHTABLE_LOOKUP, Register_HASHTABLE_LOOKUP());
- AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX());
+ AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION());
AddBuiltin(BuiltinOperator_ADD, Register_ADD());
AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND());
@@ -225,12 +229,15 @@
Register_RESIZE_NEAREST_NEIGHBOR());
AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM());
AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH());
- AddBuiltin(BuiltinOperator_GATHER, Register_GATHER());
+ AddBuiltin(BuiltinOperator_GATHER, Register_GATHER(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE());
AddBuiltin(BuiltinOperator_MEAN, Register_MEAN());
AddBuiltin(BuiltinOperator_DIV, Register_DIV());
AddBuiltin(BuiltinOperator_SUB, Register_SUB());
- AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT());
+ AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_SPLIT_V, Register_SPLIT_V());
AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
@@ -245,8 +252,12 @@
AddBuiltin(BuiltinOperator_PRELU, Register_PRELU());
AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
- AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
- AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN());
+ AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX(),
+ /* min_version */ 1,
+ /* max_version */ 2);
+ AddBuiltin(BuiltinOperator_ARG_MIN, Register_ARG_MIN(),
+ /* min_version */ 1,
+ /* max_version */ 2);
AddBuiltin(BuiltinOperator_GREATER, Register_GREATER());
AddBuiltin(BuiltinOperator_GREATER_EQUAL, Register_GREATER_EQUAL());
AddBuiltin(BuiltinOperator_LESS, Register_LESS());
@@ -290,16 +301,18 @@
AddBuiltin(BuiltinOperator_MIRROR_PAD, Register_MIRROR_PAD());
AddBuiltin(BuiltinOperator_UNIQUE, Register_UNIQUE());
AddBuiltin(BuiltinOperator_REVERSE_V2, Register_REVERSE_V2());
+ AddBuiltin(BuiltinOperator_ADD_N, Register_ADD_N());
// TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
// custom ops aren't always included by default.
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
- AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
- AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
+
+ // WARNING: Control flow ops are experimental and subject to change.
+ AddCustom("Experimental_If", tflite::ops::custom::Register_IF());
}
} // namespace builtin
diff --git a/tensorflow/lite/kernels/register_ref.cc b/tensorflow/lite/kernels/register_ref.cc
index 6840ea3..faa864b 100644
--- a/tensorflow/lite/kernels/register_ref.cc
+++ b/tensorflow/lite/kernels/register_ref.cc
@@ -22,10 +22,8 @@
namespace custom {
TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
-TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
-TfLiteRegistration* Register_RELU_1();
} // namespace custom
@@ -286,8 +284,6 @@
AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
AddCustom("AudioSpectrogram",
tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
- AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
- AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
AddCustom("TFLite_Detection_PostProcess",
tflite::ops::custom::Register_DETECTION_POSTPROCESS());
}
diff --git a/tensorflow/lite/kernels/relu1.cc b/tensorflow/lite/kernels/relu1.cc
deleted file mode 100644
index 5a55631..0000000
--- a/tensorflow/lite/kernels/relu1.cc
+++ /dev/null
@@ -1,59 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/lite/context.h"
-#include "tensorflow/lite/kernels/internal/tensor.h"
-#include "tensorflow/lite/kernels/kernel_util.h"
-
-namespace tflite {
-namespace ops {
-namespace custom {
-namespace relu1 {
-
-TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
- TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
- TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
- const TfLiteTensor* input = GetInput(context, node, 0);
- TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
- TfLiteTensor* output = GetOutput(context, node, 0);
- output->type = input->type;
- return context->ResizeTensor(context, output,
- TfLiteIntArrayCopy(input->dims));
-}
-
-// This is derived from lite/kernels/activations.cc.
-TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const TfLiteTensor* input = GetInput(context, node, 0);
- TfLiteTensor* output = GetOutput(context, node, 0);
- const int elements = NumElements(input);
- const float* in = input->data.f;
- const float* in_end = in + elements;
- float* out = output->data.f;
- for (; in < in_end; ++in, ++out) {
- *out = std::min(std::max(0.f, *in), 1.f);
- }
- return kTfLiteOk;
-}
-
-} // namespace relu1
-
-TfLiteRegistration* Register_RELU_1() {
- static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
- relu1::Prepare, relu1::Eval};
- return &r;
-}
-
-} // namespace custom
-} // namespace ops
-} // namespace tflite
diff --git a/tensorflow/lite/kernels/relu1_test.cc b/tensorflow/lite/kernels/relu1_test.cc
deleted file mode 100644
index f52d10b..0000000
--- a/tensorflow/lite/kernels/relu1_test.cc
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include <gtest/gtest.h>
-#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
-#include "tensorflow/lite/kernels/register.h"
-#include "tensorflow/lite/kernels/test_util.h"
-
-namespace tflite {
-namespace ops {
-namespace custom {
-
-TfLiteRegistration* Register_RELU_1();
-
-namespace {
-
-using ::testing::ElementsAreArray;
-
-class BaseActivationsOpModel : public SingleOpModel {
- public:
- explicit BaseActivationsOpModel(const TensorData& input) {
- input_ = AddInput(input);
- output_ = AddOutput({input.type, {}});
- flexbuffers::Builder fbb;
- fbb.Map([&]() {});
- fbb.Finish();
- SetCustomOp("RELU_1", fbb.GetBuffer(), Register_RELU_1);
- BuildInterpreter({GetShape(input_)});
- }
-
- protected:
- int input_;
- int output_;
-};
-
-class FloatActivationsOpModel : public BaseActivationsOpModel {
- public:
- using BaseActivationsOpModel::BaseActivationsOpModel;
-
- void SetInput(std::initializer_list<float> data) {
- PopulateTensor(input_, data);
- }
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
-};
-
-TEST(FloatActivationsOpTest, Relu1) {
- FloatActivationsOpModel m(/*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}});
- m.SetInput({
- 0.0, -0.6, 0.2, -0.4, //
- 0.3, -2.0, 1.1, -0.1, //
- });
- m.Invoke();
- EXPECT_THAT(m.GetOutput(), ElementsAreArray({
- 0.0, 0.0, 0.2, 0.0, //
- 0.3, 0.0, 1.0, 0.0, //
- }));
-}
-
-} // namespace
-} // namespace custom
-} // namespace ops
-} // namespace tflite
-
-int main(int argc, char** argv) {
- ::tflite::LogToStderr();
- ::testing::InitGoogleTest(&argc, argv);
- return RUN_ALL_TESTS();
-}
diff --git a/tensorflow/lite/kernels/split.cc b/tensorflow/lite/kernels/split.cc
index 7902ed2..4112898 100644
--- a/tensorflow/lite/kernels/split.cc
+++ b/tensorflow/lite/kernels/split.cc
@@ -76,9 +76,9 @@
TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
auto input_type = op_context.input->type;
- TF_LITE_ENSURE(context, input_type == kTfLiteFloat32 ||
- input_type == kTfLiteUInt8 ||
- input_type == kTfLiteInt16);
+ TF_LITE_ENSURE(context,
+ input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
+ input_type == kTfLiteInt8 || input_type == kTfLiteInt16);
for (int i = 0; i < NumOutputs(node); ++i) {
GetOutput(context, node, i)->type = input_type;
}
@@ -137,15 +137,19 @@
TF_LITE_SPLIT(uint8_t);
break;
}
+ case kTfLiteInt8: {
+ TF_LITE_SPLIT(int8_t);
+ break;
+ }
case kTfLiteInt16: {
TF_LITE_SPLIT(int16_t);
break;
}
default:
- context->ReportError(
- context,
- "Only float32, uint8 and int16 are currently supported, got %d.",
- op_context.input->type);
+ context->ReportError(context,
+ "Only float32, uint8, int8 and int16 are currently "
+ "supported, got %d.",
+ op_context.input->type);
return kTfLiteError;
}
#undef TF_LITE_SPLIT
diff --git a/tensorflow/lite/kernels/split_test.cc b/tensorflow/lite/kernels/split_test.cc
index f3d9ea3..aa23007 100644
--- a/tensorflow/lite/kernels/split_test.cc
+++ b/tensorflow/lite/kernels/split_test.cc
@@ -47,13 +47,15 @@
}
}
- void SetInput(std::initializer_list<float> data) {
+ template <typename T>
+ void SetInput(std::initializer_list<T> data) {
PopulateTensor(input_, data);
}
void SetAxis(int axis) { PopulateTensor(axis_, {axis}); }
- std::vector<float> GetOutput(int i) {
- return ExtractVector<float>(outputs_[i]);
+ template <typename T>
+ std::vector<T> GetOutput(int i) {
+ return ExtractVector<T>(outputs_[i]);
}
std::vector<int> GetOutputShape(int i) { return GetTensorShape(outputs_[i]); }
@@ -63,33 +65,34 @@
std::vector<int> outputs_;
};
-using TensorValues = std::initializer_list<float>;
-
+template <typename T>
void Check(int axis, int num_splits, std::initializer_list<int> input_shape,
std::initializer_list<int> output_shape,
- const TensorValues& input_data,
- const std::vector<TensorValues>& output_data) {
+ const std::initializer_list<T>& input_data,
+ const std::vector<std::initializer_list<T>>& output_data,
+ const TensorType& type = TensorType_FLOAT32) {
auto debug = [&](int i) {
std::stringstream ss;
ss << "for output tensor " << i << " axis=" << axis
<< " and num_splits=" << num_splits;
return ss.str();
};
- SplitOpModel m({TensorType_FLOAT32, input_shape}, num_splits);
+ SplitOpModel m({type, input_shape}, num_splits);
m.SetInput(input_data);
m.SetAxis(axis);
m.Invoke();
for (int i = 0; i < num_splits; ++i) {
- EXPECT_THAT(m.GetOutput(i), ElementsAreArray(output_data[i])) << debug(i);
+ EXPECT_THAT(m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
+ << debug(i);
EXPECT_THAT(m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
}
- SplitOpModel const_m({TensorType_FLOAT32, input_shape}, num_splits, axis);
+ SplitOpModel const_m({type, input_shape}, num_splits, axis);
const_m.SetInput(input_data);
const_m.Invoke();
for (int i = 0; i < num_splits; ++i) {
- EXPECT_THAT(const_m.GetOutput(i), ElementsAreArray(output_data[i]))
+ EXPECT_THAT(const_m.GetOutput<T>(i), ElementsAreArray(output_data[i]))
<< debug(i);
EXPECT_THAT(const_m.GetOutputShape(i), ElementsAreArray(output_shape))
<< debug(i);
@@ -97,44 +100,75 @@
}
TEST(SplitOpTest, FourDimensional) {
- Check(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
- {
- {1, 2, 3, 4, 5, 6, 7, 8},
- {9, 10, 11, 12, 13, 14, 15, 16},
- });
- Check(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
- {
- {1, 2, 3, 4, 9, 10, 11, 12},
- {5, 6, 7, 8, 13, 14, 15, 16},
- });
- Check(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
- {
- {1, 2, 5, 6, 9, 10, 13, 14},
- {3, 4, 7, 8, 11, 12, 15, 16},
- });
- Check(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
- {
- {1, 3, 5, 7, 9, 11, 13, 15},
- {2, 4, 6, 8, 10, 12, 14, 16},
- });
+ Check<float>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 3, 4, 5, 6, 7, 8},
+ {9, 10, 11, 12, 13, 14, 15, 16},
+ });
+ Check<float>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 3, 4, 9, 10, 11, 12},
+ {5, 6, 7, 8, 13, 14, 15, 16},
+ });
+ Check<float>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 5, 6, 9, 10, 13, 14},
+ {3, 4, 7, 8, 11, 12, 15, 16},
+ });
+ Check<float>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 3, 5, 7, 9, 11, 13, 15},
+ {2, 4, 6, 8, 10, 12, 14, 16},
+ });
+}
+
+TEST(SplitOpTest, FourDimensionalInt8) {
+ Check<int8_t>(/*axis=*/0, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 3, 4, 5, 6, 7, 8},
+ {9, 10, 11, 12, 13, 14, 15, 16},
+ },
+ TensorType_INT8);
+ Check<int8_t>(/*axis=*/1, /*num_splits=*/2, {2, 2, 2, 2}, {2, 1, 2, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 3, 4, 9, 10, 11, 12},
+ {5, 6, 7, 8, 13, 14, 15, 16},
+ },
+ TensorType_INT8);
+ Check<int8_t>(/*axis=*/2, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 1, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 5, 6, 9, 10, 13, 14},
+ {3, 4, 7, 8, 11, 12, 15, 16},
+ },
+ TensorType_INT8);
+ Check<int8_t>(/*axis=*/3, /*num_splits=*/2, {2, 2, 2, 2}, {2, 2, 2, 1},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 3, 5, 7, 9, 11, 13, 15},
+ {2, 4, 6, 8, 10, 12, 14, 16},
+ },
+ TensorType_INT8);
}
TEST(SplitOpTest, OneDimensional) {
- Check(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
- {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
+ Check<float>(/*axis=*/0, /*num_splits=*/8, {8}, {1}, {1, 2, 3, 4, 5, 6, 7, 8},
+ {{1}, {2}, {3}, {4}, {5}, {6}, {7}, {8}});
}
TEST(SplitOpTest, NegativeAxis) {
- Check(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
- {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
- {
- {1, 2, 3, 4, 5, 6, 7, 8},
- {9, 10, 11, 12, 13, 14, 15, 16},
- });
+ Check<float>(/*axis=*/-4, /*num_splits=*/2, {2, 2, 2, 2}, {1, 2, 2, 2},
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
+ {
+ {1, 2, 3, 4, 5, 6, 7, 8},
+ {9, 10, 11, 12, 13, 14, 15, 16},
+ });
}
} // namespace
diff --git a/tensorflow/lite/kernels/subgraph_test_util.cc b/tensorflow/lite/kernels/subgraph_test_util.cc
new file mode 100644
index 0000000..0f41f9c
--- /dev/null
+++ b/tensorflow/lite/kernels/subgraph_test_util.cc
@@ -0,0 +1,159 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
+
+#include "flatbuffers/flexbuffers.h" // TF:flatbuffers
+#include "tensorflow/lite/core/subgraph.h"
+#include "tensorflow/lite/kernels/kernel_util.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/kernels/test_util.h"
+#include "tensorflow/lite/model.h"
+
+namespace tflite {
+
+namespace ops {
+namespace builtin {
+// ADD and MUL are used to test simple branch.
+TfLiteRegistration* Register_ADD();
+TfLiteRegistration* Register_MUL();
+// ADD and MUL are used to test dynamic sized subgraphs.
+TfLiteRegistration* Register_PAD();
+} // namespace builtin
+namespace custom {
+TfLiteRegistration* Register_IF();
+} // namespace custom
+} // namespace ops
+
+namespace subgraph_test_util {
+
+void SetupTensor(Subgraph* subgraph, int tensor_index, TfLiteType type) {
+ ASSERT_EQ(subgraph->SetTensorParametersReadWrite(tensor_index, type, "", 0,
+ nullptr, {}, false),
+ kTfLiteOk);
+}
+
+void BuildAddSubgraph(Subgraph* subgraph) {
+ int first_new_tensor_index;
+ ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
+ ASSERT_EQ(first_new_tensor_index, 0);
+ ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
+ ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
+
+ SetupTensor(subgraph, 0, kTfLiteInt32);
+ SetupTensor(subgraph, 1, kTfLiteInt32);
+ SetupTensor(subgraph, 2, kTfLiteInt32);
+
+ TfLiteAddParams* params =
+ reinterpret_cast<TfLiteAddParams*>(malloc(sizeof(TfLiteAddParams)));
+ params->activation = kTfLiteActNone;
+ int node_index;
+ subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
+ ::tflite::ops::builtin::Register_ADD(),
+ &node_index);
+}
+
+// Build a subgraph with an mul op. Helper function for testing.
+void BuildMulSubgraph(Subgraph* subgraph) {
+ int first_new_tensor_index;
+ ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
+ ASSERT_EQ(first_new_tensor_index, 0);
+ ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
+ ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
+
+ SetupTensor(subgraph, 0, kTfLiteInt32);
+ SetupTensor(subgraph, 1, kTfLiteInt32);
+ SetupTensor(subgraph, 2, kTfLiteInt32);
+
+ TfLiteMulParams* params =
+ reinterpret_cast<TfLiteMulParams*>(malloc(sizeof(TfLiteMulParams)));
+ params->activation = kTfLiteActNone;
+ int node_index;
+ subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
+ ::tflite::ops::builtin::Register_MUL(),
+ &node_index);
+}
+
+// Build a subgraph with a pad op. Helper function for testing.
+void BuildPadSubgraph(Subgraph* subgraph) {
+ int first_new_tensor_index;
+ ASSERT_EQ(subgraph->AddTensors(3, &first_new_tensor_index), kTfLiteOk);
+ ASSERT_EQ(first_new_tensor_index, 0);
+ ASSERT_EQ(subgraph->SetInputs({0, 1}), kTfLiteOk);
+ ASSERT_EQ(subgraph->SetOutputs({2}), kTfLiteOk);
+
+ SetupTensor(subgraph, 0, kTfLiteInt32);
+ SetupTensor(subgraph, 1, kTfLiteInt32);
+ SetupTensor(subgraph, 2, kTfLiteInt32);
+
+ TfLitePadParams* params =
+ reinterpret_cast<TfLitePadParams*>(malloc(sizeof(TfLitePadParams)));
+ int node_index;
+ subgraph->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params,
+ ::tflite::ops::builtin::Register_PAD(),
+ &node_index);
+}
+
+void BuildIfSubgraph(Subgraph* subgraph) {
+ int first_new_tensor_index;
+ ASSERT_EQ(subgraph->AddTensors(4, &first_new_tensor_index), kTfLiteOk);
+ ASSERT_EQ(first_new_tensor_index, 0);
+ ASSERT_EQ(subgraph->SetInputs({0, 1, 2}), kTfLiteOk);
+ ASSERT_EQ(subgraph->SetOutputs({3}), kTfLiteOk);
+
+ SetupTensor(subgraph, 0, kTfLiteBool);
+ SetupTensor(subgraph, 1, kTfLiteInt32);
+ SetupTensor(subgraph, 2, kTfLiteInt32);
+ SetupTensor(subgraph, 3, kTfLiteInt32);
+
+ flexbuffers::Builder fbb;
+ fbb.Map([&]() {
+ fbb.Int("then_subgraph_index", 1);
+ fbb.Int("else_subgraph_index", 2);
+ });
+ fbb.Finish();
+ const auto& buffer = fbb.GetBuffer();
+
+ int node_index;
+ subgraph->AddNodeWithParameters(
+ {0, 1, 2}, {3}, reinterpret_cast<const char*>(buffer.data()),
+ buffer.size(), nullptr, ::tflite::ops::custom::Register_IF(),
+ &node_index);
+}
+
+void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data) {
+ int count = NumElements(tensor);
+ ASSERT_EQ(count, data.size());
+ for (int i = 0; i < count; ++i) {
+ tensor->data.i32[i] = data[i];
+ }
+}
+
+void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
+ const std::vector<int32_t>& data) {
+ ASSERT_EQ(tensor->dims->size, shape.size());
+ for (int i = 0; i < tensor->dims->size; ++i) {
+ ASSERT_EQ(tensor->dims->data[i], shape[i]);
+ }
+ ASSERT_EQ(tensor->type, kTfLiteInt32);
+ int count = NumElements(tensor);
+ ASSERT_EQ(count, data.size());
+ for (int i = 0; i < count; ++i) {
+ EXPECT_EQ(tensor->data.i32[i], data[i]);
+ }
+}
+
+} // namespace subgraph_test_util
+} // namespace tflite
diff --git a/tensorflow/lite/kernels/subgraph_test_util.h b/tensorflow/lite/kernels/subgraph_test_util.h
new file mode 100644
index 0000000..6dc5207
--- /dev/null
+++ b/tensorflow/lite/kernels/subgraph_test_util.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This module provides helper functions for testing the interaction between
+// control flow ops and subgraphs.
+// For convenience, we mostly only use `kTfLiteInt32` in this module.
+
+#ifndef TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
+#define TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
+
+#include "tensorflow/lite/core/subgraph.h"
+
+namespace tflite {
+namespace subgraph_test_util {
+
+// Build a subgraph with a single Add op.
+// 2 inputs. 1 output.
+void BuildAddSubgraph(Subgraph* subgraph);
+
+// Build a subgraph with a single Mul op.
+// 2 inputs. 1 output.
+void BuildMulSubgraph(Subgraph* subgraph);
+
+// Build a subgraph with a single Pad op.
+// 2 inputs. 1 output.
+void BuildPadSubgraph(Subgraph* subgraph);
+
+// Build a subgraph with a single If op.
+// 3 inputs:
+// The 1st input is condition with boolean type.
+// The 2nd and 3rd inputs are feed input the branch subgraphs.
+// 1 output.
+void BuildIfSubgraph(Subgraph* subgraph);
+
+// Fill a `TfLiteTensor` with a 32-bits integer vector.
+// Preconditions:
+// * The tensor must have `kTfLiteInt32` type.
+// * The tensor must be allocated.
+// * The element count of the tensor must be equal to the length or
+// the vector.
+void FillIntTensor(TfLiteTensor* tensor, const std::vector<int32_t>& data);
+
+// Check if the shape and data of a tensor is as expected.
+void CheckIntTensor(const TfLiteTensor* tensor, const std::vector<int>& shape,
+ const std::vector<int32_t>& data);
+
+} // namespace subgraph_test_util
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_KERNELS_SUBGRAPH_TEST_UTIL_H_
diff --git a/tensorflow/lite/kernels/subgraph_test_util_test.cc b/tensorflow/lite/kernels/subgraph_test_util_test.cc
new file mode 100644
index 0000000..88cd8f5
--- /dev/null
+++ b/tensorflow/lite/kernels/subgraph_test_util_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/kernels/subgraph_test_util.h"
+#include <gtest/gtest.h>
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/test_util.h"
+
+namespace tflite {
+
+namespace subgraph_test_util {
+
+namespace {
+
+// SubGraphTestUtilTest tests the helper functions defined in this file.
+TEST(SubGraphTestUtilTest, TestBuildAddSubgraph) {
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ BuildAddSubgraph(&interpreter->primary_subgraph());
+
+ interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
+ interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
+ ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
+
+ TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
+ CheckIntTensor(output, {1, 2}, {6, 9});
+}
+
+TEST(SubGraphTestUtilTest, TestBuildMulSubgraph) {
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ BuildMulSubgraph(&interpreter->primary_subgraph());
+
+ interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
+ interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
+ ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
+
+ TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
+ CheckIntTensor(output, {1, 2}, {5, 14});
+}
+
+TEST(SubGraphTestUtilTest, TestBuildPadSubgraph) {
+ std::unique_ptr<Interpreter> interpreter(new Interpreter);
+ BuildPadSubgraph(&interpreter->primary_subgraph());
+
+ interpreter->ResizeInputTensor(interpreter->inputs()[0], {2});
+ interpreter->ResizeInputTensor(interpreter->inputs()[1], {1, 2});
+ ASSERT_EQ(interpreter->AllocateTensors(), kTfLiteOk);
+
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[0]), {5, 7});
+ FillIntTensor(interpreter->tensor(interpreter->inputs()[1]), {1, 2});
+ ASSERT_EQ(interpreter->Invoke(), kTfLiteOk);
+
+ TfLiteTensor* output = interpreter->tensor(interpreter->outputs()[0]);
+ CheckIntTensor(output, {5}, {0, 5, 7, 0, 0});
+}
+
+} // namespace
+} // namespace subgraph_test_util
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::tflite::LogToStderr();
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/kernels/test_util.h b/tensorflow/lite/kernels/test_util.h
index 83f0868..9bec8ce 100644
--- a/tensorflow/lite/kernels/test_util.h
+++ b/tensorflow/lite/kernels/test_util.h
@@ -85,6 +85,24 @@
// the actual data is known. This mimics what happens in practice: quantization
// parameters are calculated during training or post training..
struct TensorData {
+ TensorData(TensorType type = TensorType_FLOAT32, std::vector<int> shape = {},
+ float min = 0.0f, float max = 0.0f, float scale = 0.0f,
+ int32_t zero_point = 0, bool per_channel_quantization = false,
+ std::vector<float> per_channel_quantization_scales = {},
+ std::vector<int64_t> per_channel_quantization_offsets = {},
+ int32_t channel_index = 0)
+ : type(type),
+ shape(shape),
+ min(min),
+ max(max),
+ scale(scale),
+ zero_point(zero_point),
+ per_channel_quantization(per_channel_quantization),
+ per_channel_quantization_scales(
+ std::move(per_channel_quantization_scales)),
+ per_channel_quantization_offsets(
+ std::move(per_channel_quantization_offsets)),
+ channel_index(channel_index) {}
TensorType type;
std::vector<int> shape;
float min;
diff --git a/tensorflow/lite/minimal_logging.cc b/tensorflow/lite/minimal_logging.cc
new file mode 100644
index 0000000..8768ef6
--- /dev/null
+++ b/tensorflow/lite/minimal_logging.cc
@@ -0,0 +1,44 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/minimal_logging.h"
+
+#include <cstdarg>
+
+namespace tflite {
+namespace logging_internal {
+
+void MinimalLogger::Log(LogSeverity severity, const char* format, ...) {
+ va_list args;
+ va_start(args, format);
+ VLog(severity, format, args);
+ va_end(args);
+}
+
+const char* MinimalLogger::GetSeverityName(LogSeverity severity) {
+ switch (severity) {
+ case TFLITE_LOG_INFO:
+ return "INFO";
+ case TFLITE_LOG_WARNING:
+ return "WARNING";
+ case TFLITE_LOG_ERROR:
+ return "ERROR";
+ default:
+ return "<Unknown severity>";
+ }
+}
+
+} // namespace logging_internal
+} // namespace tflite
diff --git a/tensorflow/lite/minimal_logging.h b/tensorflow/lite/minimal_logging.h
new file mode 100644
index 0000000..7682ed8
--- /dev/null
+++ b/tensorflow/lite/minimal_logging.h
@@ -0,0 +1,56 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_MINIMAL_LOGGING_H_
+#define TENSORFLOW_LITE_MINIMAL_LOGGING_H_
+
+#include <cstdarg>
+
+namespace tflite {
+
+enum LogSeverity {
+ TFLITE_LOG_INFO = 0,
+ TFLITE_LOG_WARNING = 1,
+ TFLITE_LOG_ERROR = 2,
+};
+
+namespace logging_internal {
+
+// Helper class for simple platform-specific console logging. Note that we
+// explicitly avoid the convenience of ostream-style logging to minimize binary
+// size impact.
+class MinimalLogger {
+ public:
+ // Logging hook that takes variadic args.
+ static void Log(LogSeverity severity, const char* format, ...);
+
+ // Logging hook that takes a formatted va_list.
+ static void VLog(LogSeverity severity, const char* format, va_list args);
+
+ private:
+ static const char* GetSeverityName(LogSeverity severity);
+};
+
+} // namespace logging_internal
+} // namespace tflite
+
+// Convenience macro for basic internal logging in production builds.
+// Note: This should never be used for debug-type logs, as it will *not* be
+// stripped in release optimized builds. In general, prefer the error reporting
+// APIs for developer-facing errors, and only use this for diagnostic output
+// that should always be logged in user builds.
+#define TFLITE_LOG_PROD(severity, format, ...) \
+ tflite::logging_internal::MinimalLogger::Log(severity, format, ##__VA_ARGS__);
+
+#endif // TENSORFLOW_LITE_MINIMAL_LOGGING_H_
diff --git a/tensorflow/lite/minimal_logging_android.cc b/tensorflow/lite/minimal_logging_android.cc
new file mode 100644
index 0000000..f87e6fa
--- /dev/null
+++ b/tensorflow/lite/minimal_logging_android.cc
@@ -0,0 +1,55 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/minimal_logging.h"
+
+#include <android/log.h>
+#include <cstdio>
+
+namespace tflite {
+namespace logging_internal {
+namespace {
+
+int GetPlatformSeverity(LogSeverity severity) {
+ switch (severity) {
+ case TFLITE_LOG_INFO:
+ return ANDROID_LOG_INFO;
+ case TFLITE_LOG_WARNING:
+ return ANDROID_LOG_WARN;
+ case TFLITE_LOG_ERROR:
+ return ANDROID_LOG_ERROR;
+ default:
+ return ANDROID_LOG_DEBUG;
+ }
+}
+
+} // namespace
+
+void MinimalLogger::VLog(LogSeverity severity, const char* format,
+ va_list args) {
+ // First log to Android's explicit log(cat) API.
+ va_list args_for_android_log;
+ va_copy(args_for_android_log, args);
+ __android_log_vprint(GetPlatformSeverity(severity), "tflite", format, args);
+ va_end(args_for_android_log);
+
+ // Also print to stderr for standard console applications.
+ fprintf(stderr, "%s: ", GetSeverityName(severity));
+ vfprintf(stderr, format, args);
+ fputc('\n', stderr);
+}
+
+} // namespace logging_internal
+} // namespace tflite
diff --git a/tensorflow/lite/minimal_logging_default.cc b/tensorflow/lite/minimal_logging_default.cc
new file mode 100644
index 0000000..9fa13e4
--- /dev/null
+++ b/tensorflow/lite/minimal_logging_default.cc
@@ -0,0 +1,31 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/minimal_logging.h"
+
+#include <cstdio>
+
+namespace tflite {
+namespace logging_internal {
+
+void MinimalLogger::VLog(LogSeverity severity, const char* format,
+ va_list args) {
+ fprintf(stderr, "%s: ", GetSeverityName(severity));
+ vfprintf(stderr, format, args);
+ fputc('\n', stderr);
+}
+
+} // namespace logging_internal
+} // namespace tflite
diff --git a/tensorflow/lite/minimal_logging_ios.cc b/tensorflow/lite/minimal_logging_ios.cc
new file mode 100644
index 0000000..a774682
--- /dev/null
+++ b/tensorflow/lite/minimal_logging_ios.cc
@@ -0,0 +1,47 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/minimal_logging.h"
+
+#include <syslog.h>
+#include <cstdarg>
+
+namespace tflite {
+namespace logging_internal {
+namespace {
+
+int GetPlatformSeverity(LogSeverity severity) {
+ switch (severity) {
+ case TFLITE_LOG_INFO:
+ return LOG_INFO;
+ case TFLITE_LOG_WARNING:
+ return LOG_WARNING;
+ case TFLITE_LOG_ERROR:
+ return LOG_ERR;
+ default:
+ return LOG_DEBUG;
+ }
+}
+
+} // namespace
+
+void MinimalLogger::VLog(LogSeverity severity, const char* format,
+ va_list args) {
+ // TODO(b/123704468): Use os_log when available.
+ vsyslog(GetPlatformSeverity(severity), format, args);
+}
+
+} // namespace logging_internal
+} // namespace tflite
diff --git a/tensorflow/lite/minimal_logging_test.cc b/tensorflow/lite/minimal_logging_test.cc
new file mode 100644
index 0000000..e59425a
--- /dev/null
+++ b/tensorflow/lite/minimal_logging_test.cc
@@ -0,0 +1,60 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/lite/minimal_logging.h"
+
+#include <string>
+
+#include <gtest/gtest.h>
+
+namespace tflite {
+
+TEST(MinimalLogging, Basic) {
+ testing::internal::CaptureStderr();
+ TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Foo");
+ EXPECT_EQ("INFO: Foo\n", testing::internal::GetCapturedStderr());
+}
+
+TEST(MinimalLogging, BasicFormatted) {
+ testing::internal::CaptureStderr();
+ TFLITE_LOG_PROD(TFLITE_LOG_INFO, "Foo %s %s", "Bar", "Baz");
+ EXPECT_EQ("INFO: Foo Bar Baz\n", testing::internal::GetCapturedStderr());
+}
+
+TEST(MinimalLogging, Warn) {
+ testing::internal::CaptureStderr();
+ TFLITE_LOG_PROD(TFLITE_LOG_WARNING, "One", "");
+ EXPECT_EQ("WARNING: One\n", testing::internal::GetCapturedStderr());
+}
+
+TEST(MinimalLogging, Error) {
+ testing::internal::CaptureStderr();
+ TFLITE_LOG_PROD(TFLITE_LOG_ERROR, "Two");
+ EXPECT_EQ("ERROR: Two\n", testing::internal::GetCapturedStderr());
+}
+
+TEST(MinimalLogging, UnknownSeverity) {
+ testing::internal::CaptureStderr();
+ TFLITE_LOG_PROD(static_cast<LogSeverity>(-1), "Three");
+ EXPECT_EQ("<Unknown severity>: Three\n",
+ testing::internal::GetCapturedStderr());
+}
+
+} // namespace tflite
+
+int main(int argc, char** argv) {
+ ::testing::InitGoogleTest(&argc, argv);
+ return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/lite/model.cc b/tensorflow/lite/model.cc
index c736685..e333138 100644
--- a/tensorflow/lite/model.cc
+++ b/tensorflow/lite/model.cc
@@ -246,11 +246,11 @@
TfLiteStatus InterpreterBuilder::ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
- Interpreter* interpreter) {
+ Subgraph* subgraph) {
TfLiteStatus status = kTfLiteOk;
// Reduce the number of redundant allocations
- interpreter->ReserveNodes(operators->Length());
+ subgraph->ReserveNodes(operators->Length());
for (int i = 0; i < operators->Length(); ++i) {
const auto* op = operators->Get(i);
@@ -280,7 +280,7 @@
}
if (op->custom_options()) {
- interpreter->AddNodeWithParameters(
+ subgraph->AddNodeWithParameters(
FlatBufferIntArrayToVector(op->inputs()),
FlatBufferIntArrayToVector(op->outputs()),
reinterpret_cast<const char*>(op->custom_options()->data()),
@@ -290,10 +290,9 @@
MallocDataAllocator malloc_allocator;
TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
&malloc_allocator, &builtin_data));
- interpreter->AddNodeWithParameters(
- FlatBufferIntArrayToVector(op->inputs()),
- FlatBufferIntArrayToVector(op->outputs()), nullptr, 0, builtin_data,
- registration);
+ subgraph->AddNodeWithParameters(FlatBufferIntArrayToVector(op->inputs()),
+ FlatBufferIntArrayToVector(op->outputs()),
+ nullptr, 0, builtin_data, registration);
}
}
@@ -353,11 +352,11 @@
TfLiteStatus InterpreterBuilder::ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
- Interpreter* interpreter) {
+ Subgraph* subgraph) {
TfLiteStatus status = kTfLiteOk;
// A little helper to get the names of inputs and outputs. Note that they
- // must outlive the interpreter.
+ // must outlive the subgraph.
auto get_name = [](const tflite::Tensor* t) -> const char* {
auto name = t->name();
if (name) return name->c_str();
@@ -418,7 +417,7 @@
status = kTfLiteError;
}
- if (interpreter->SetTensorParametersReadOnly(
+ if (subgraph->SetTensorParametersReadOnly(
i, type, get_name(tensor), dims, quantization, buffer_ptr,
buffer_size, allocation_) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
@@ -426,9 +425,9 @@
status = kTfLiteError;
}
} else {
- if (interpreter->SetTensorParametersReadWrite(i, type, get_name(tensor),
- dims, quantization,
- is_variable) != kTfLiteOk) {
+ if (subgraph->SetTensorParametersReadWrite(i, type, get_name(tensor),
+ dims, quantization,
+ is_variable) != kTfLiteOk) {
error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
i);
status = kTfLiteError;
@@ -510,42 +509,56 @@
// Construct interpreter with correct number of tensors and operators.
auto* subgraphs = model_->subgraphs();
auto* buffers = model_->buffers();
- if (subgraphs->size() != 1) {
- error_reporter_->Report("Only 1 subgraph is currently supported.\n");
+
+ if (subgraphs->size() == 0) {
+ error_reporter_->Report("No subgraph in the model.\n");
return cleanup_and_error();
}
- const tflite::SubGraph* subgraph = (*subgraphs)[0];
- auto operators = subgraph->operators();
- auto tensors = subgraph->tensors();
- if (!operators || !tensors || !buffers) {
- error_reporter_->Report(
- "Did not get operators, tensors, or buffers in input flat buffer.\n");
- return cleanup_and_error();
- }
+
interpreter->reset(new Interpreter(error_reporter_));
- if ((**interpreter).AddTensors(tensors->Length()) != kTfLiteOk) {
- return cleanup_and_error();
+ (*interpreter)->SetNumThreads(num_threads);
+ if (subgraphs->Length() > 1) {
+ (*interpreter)->AddSubgraphs(subgraphs->Length() - 1);
}
- // Set num threads
- (**interpreter).SetNumThreads(num_threads);
- // Parse inputs/outputs
- (**interpreter).SetInputs(FlatBufferIntArrayToVector(subgraph->inputs()));
- (**interpreter).SetOutputs(FlatBufferIntArrayToVector(subgraph->outputs()));
- // Finally setup nodes and tensors
- if (ParseNodes(operators, interpreter->get()) != kTfLiteOk)
- return cleanup_and_error();
- if (ParseTensors(buffers, tensors, interpreter->get()) != kTfLiteOk)
- return cleanup_and_error();
-
- std::vector<int> variables;
- for (int i = 0; i < (*interpreter)->tensors_size(); ++i) {
- auto* tensor = (*interpreter)->tensor(i);
- if (tensor->is_variable) {
- variables.push_back(i);
+ for (int subgraph_index = 0; subgraph_index < subgraphs->Length();
+ ++subgraph_index) {
+ const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
+ tflite::Subgraph* modified_subgraph =
+ (*interpreter)->subgraph(subgraph_index);
+ auto operators = subgraph->operators();
+ auto tensors = subgraph->tensors();
+ if (!operators || !tensors || !buffers) {
+ error_reporter_->Report(
+ "Did not get operators, tensors, or buffers in subgraph %d.\n",
+ subgraph_index);
+ return cleanup_and_error();
}
+ if (modified_subgraph->AddTensors(tensors->Length()) != kTfLiteOk) {
+ return cleanup_and_error();
+ }
+ // Set num threads
+ // Parse inputs/outputs
+ modified_subgraph->SetInputs(
+ FlatBufferIntArrayToVector(subgraph->inputs()));
+ modified_subgraph->SetOutputs(
+ FlatBufferIntArrayToVector(subgraph->outputs()));
+
+ // Finally setup nodes and tensors
+ if (ParseNodes(operators, modified_subgraph) != kTfLiteOk)
+ return cleanup_and_error();
+ if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
+ return cleanup_and_error();
+
+ std::vector<int> variables;
+ for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
+ auto* tensor = modified_subgraph->tensor(i);
+ if (tensor->is_variable) {
+ variables.push_back(i);
+ }
+ }
+ modified_subgraph->SetVariables(std::move(variables));
}
- (**interpreter).SetVariables(std::move(variables));
if (ApplyDelegates(interpreter->get()) != kTfLiteOk)
return cleanup_and_error();
diff --git a/tensorflow/lite/model.h b/tensorflow/lite/model.h
index a9bd4c9..bae4229 100644
--- a/tensorflow/lite/model.h
+++ b/tensorflow/lite/model.h
@@ -198,11 +198,11 @@
TfLiteStatus BuildLocalIndexToRegistrationMapping();
TfLiteStatus ParseNodes(
const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
- Interpreter* interpreter);
+ Subgraph* subgraph);
TfLiteStatus ParseTensors(
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
- Interpreter* interpreter);
+ Subgraph* subgraph);
TfLiteStatus ApplyDelegates(Interpreter* interpreter);
TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
TfLiteQuantization* quantization);
diff --git a/tensorflow/lite/model_test.cc b/tensorflow/lite/model_test.cc
index e677ea9..67d2380 100644
--- a/tensorflow/lite/model_test.cc
+++ b/tensorflow/lite/model_test.cc
@@ -87,20 +87,21 @@
// Make sure currently unsupported # of subgraphs are checked
// TODO(aselle): Replace this test when multiple subgraphs are supported.
-TEST(BasicFlatBufferModel, TestZeroAndMultipleSubgraphs) {
- auto m1 = FlatBufferModel::BuildFromFile(
+TEST(BasicFlatBufferModel, TestZeroSubgraphs) {
+ auto m = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/0_subgraphs.bin");
- ASSERT_TRUE(m1);
- std::unique_ptr<Interpreter> interpreter1;
- ASSERT_NE(InterpreterBuilder(*m1, TrivialResolver())(&interpreter1),
- kTfLiteOk);
+ ASSERT_TRUE(m);
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_NE(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
+}
- auto m2 = FlatBufferModel::BuildFromFile(
+TEST(BasicFlatBufferModel, TestMultipleSubgraphs) {
+ auto m = FlatBufferModel::BuildFromFile(
"tensorflow/lite/testdata/2_subgraphs.bin");
- ASSERT_TRUE(m2);
- std::unique_ptr<Interpreter> interpreter2;
- ASSERT_NE(InterpreterBuilder(*m2, TrivialResolver())(&interpreter2),
- kTfLiteOk);
+ ASSERT_TRUE(m);
+ std::unique_ptr<Interpreter> interpreter;
+ ASSERT_EQ(InterpreterBuilder(*m, TrivialResolver())(&interpreter), kTfLiteOk);
+ EXPECT_EQ(interpreter->subgraphs_size(), 2);
}
// Test what happens if we cannot bind any of the ops.
diff --git a/tensorflow/lite/nnapi/BUILD b/tensorflow/lite/nnapi/BUILD
index 7af2b09..8ee9f3c 100644
--- a/tensorflow/lite/nnapi/BUILD
+++ b/tensorflow/lite/nnapi/BUILD
@@ -31,6 +31,7 @@
],
linkopts = ["-ldl"] + select({
"//tensorflow:android": [],
+ "//tensorflow:darwin": [],
"//tensorflow:ios": [],
"//tensorflow:windows": [],
"//conditions:default": ["-lrt"],
diff --git a/tensorflow/lite/nnapi/NeuralNetworksTypes.h b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
index 573500d..ba7eaf6 100644
--- a/tensorflow/lite/nnapi/NeuralNetworksTypes.h
+++ b/tensorflow/lite/nnapi/NeuralNetworksTypes.h
@@ -39,6 +39,7 @@
ANEURALNETWORKS_TENSOR_FLOAT32 = 3,
ANEURALNETWORKS_TENSOR_INT32 = 4,
ANEURALNETWORKS_TENSOR_QUANT8_ASYMM = 5,
+ ANEURALNETWORKS_TENSOR_QUANT8_SYMM = 13,
};
/**
diff --git a/tensorflow/lite/nnapi_delegate.cc b/tensorflow/lite/nnapi_delegate.cc
index f7cb158..f69baf1 100644
--- a/tensorflow/lite/nnapi_delegate.cc
+++ b/tensorflow/lite/nnapi_delegate.cc
@@ -664,6 +664,8 @@
case tflite::BuiltinOperator_UNIQUE:
case tflite::BuiltinOperator_CEIL:
case tflite::BuiltinOperator_REVERSE_V2:
+ case tflite::BuiltinOperator_ADD_N:
+ case tflite::BuiltinOperator_GATHER_ND:
logError("Op code %d is currently not delegated to NNAPI", builtin);
return kTfLiteError;
break;
diff --git a/tensorflow/lite/python/BUILD b/tensorflow/lite/python/BUILD
index 0036662..02b8b80 100644
--- a/tensorflow/lite/python/BUILD
+++ b/tensorflow/lite/python/BUILD
@@ -13,7 +13,6 @@
visibility = ["//visibility:public"],
deps = [
"//tensorflow/lite/python/interpreter_wrapper:tensorflow_wrap_interpreter_wrapper",
- "//tensorflow/python:util",
"//third_party/py/numpy",
],
)
@@ -23,7 +22,9 @@
srcs = ["interpreter_test.py"],
data = ["//tensorflow/lite/python/testdata:interpreter_test_data"],
srcs_version = "PY2AND3",
- tags = ["no_oss"],
+ tags = [
+ "no_windows",
+ ],
deps = [
":interpreter",
"//tensorflow/python:client_testlib",
@@ -38,6 +39,14 @@
srcs = ["tflite_convert.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
+ deps = [":tflite_convert_lib"],
+)
+
+py_library(
+ name = "tflite_convert_lib",
+ srcs = ["tflite_convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
deps = [
":lite",
],
@@ -69,6 +78,21 @@
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
+ "no_windows",
+ ],
+ deps = [
+ ":lite",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
+ name = "lite_flex_test",
+ srcs = ["lite_flex_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ # TODO(b/111881877): Enable in oss after resolving op registry issues.
"no_oss",
"no_windows",
],
diff --git a/tensorflow/lite/python/interpreter_test.py b/tensorflow/lite/python/interpreter_test.py
index 7ec56a2..b217792 100644
--- a/tensorflow/lite/python/interpreter_test.py
+++ b/tensorflow/lite/python/interpreter_test.py
@@ -91,6 +91,41 @@
output_data = interpreter.get_tensor(output_details[0]['index'])
self.assertTrue((expected_output == output_data).all())
+ def testString(self):
+ interpreter = interpreter_wrapper.Interpreter(
+ model_path=resource_loader.get_path_to_datafile(
+ 'testdata/gather_string.tflite'))
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual('input', input_details[0]['name'])
+ self.assertEqual(np.string_, input_details[0]['dtype'])
+ self.assertTrue(([10] == input_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), input_details[0]['quantization'])
+ self.assertEqual('indices', input_details[1]['name'])
+ self.assertEqual(np.int64, input_details[1]['dtype'])
+ self.assertTrue(([3] == input_details[1]['shape']).all())
+ self.assertEqual((0.0, 0), input_details[1]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual('output', output_details[0]['name'])
+ self.assertEqual(np.string_, output_details[0]['dtype'])
+ self.assertTrue(([3] == output_details[0]['shape']).all())
+ self.assertEqual((0.0, 0), output_details[0]['quantization'])
+
+ test_input = np.array([1, 2, 3], dtype=np.int64)
+ interpreter.set_tensor(input_details[1]['index'], test_input)
+
+ test_input = np.array(['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'])
+ expected_output = np.array([b'b', b'c', b'd'])
+ interpreter.set_tensor(input_details[0]['index'], test_input)
+ interpreter.invoke()
+
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ self.assertTrue((expected_output == output_data).all())
+
class InterpreterTestErrorPropagation(test_util.TensorFlowTestCase):
diff --git a/tensorflow/lite/python/interpreter_wrapper/BUILD b/tensorflow/lite/python/interpreter_wrapper/BUILD
index 6de6fb4..6ec7ce4 100644
--- a/tensorflow/lite/python/interpreter_wrapper/BUILD
+++ b/tensorflow/lite/python/interpreter_wrapper/BUILD
@@ -7,13 +7,25 @@
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
cc_library(
+ name = "numpy",
+ srcs = ["numpy.cc"],
+ hdrs = ["numpy.h"],
+ deps = [
+ "//third_party/py/numpy:headers",
+ "//third_party/python_runtime:headers",
+ ],
+)
+
+cc_library(
name = "interpreter_wrapper_lib",
srcs = ["interpreter_wrapper.cc"],
hdrs = ["interpreter_wrapper.h"],
deps = [
+ ":numpy",
":python_error_reporter",
":python_utils",
"//tensorflow/lite:framework",
+ "//tensorflow/lite:string_util",
"//tensorflow/lite/kernels:builtin_ops",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
@@ -36,7 +48,9 @@
srcs = ["python_utils.cc"],
hdrs = ["python_utils.h"],
deps = [
+ ":numpy",
"//tensorflow/lite:framework",
+ "//tensorflow/lite:string_util",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
],
diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 9ccaabb..41cebf8 100644
--- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -21,16 +21,10 @@
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
-
-// Disallow Numpy 1.7 deprecated symbols.
-#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
-
-#include <Python.h>
-
-#include "numpy/arrayobject.h"
-#include "numpy/ufuncobject.h"
+#include "tensorflow/lite/string_util.h"
#if PY_MAJOR_VERSION >= 3
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
@@ -64,12 +58,6 @@
namespace {
-// Calls PyArray's initialization to initialize all the API pointers. Note that
-// this usage implies only this translation unit can use the pointers. See
-// tensorflow/python/core/numpy.cc for a strategy if we ever need to extend
-// this further.
-void ImportNumpy() { import_array1(); }
-
std::unique_ptr<tflite::Interpreter> CreateInterpreter(
const tflite::FlatBufferModel* model,
const tflite::ops::builtin::BuiltinOpResolver& resolver) {
@@ -77,7 +65,7 @@
return nullptr;
}
- ImportNumpy();
+ ::tflite::python::ImportNumpy();
std::unique_ptr<tflite::Interpreter> interpreter;
if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
@@ -267,7 +255,7 @@
}
PyArrayObject* array = reinterpret_cast<PyArrayObject*>(array_safe.get());
- const TfLiteTensor* tensor = interpreter_->tensor(i);
+ TfLiteTensor* tensor = interpreter_->tensor(i);
if (python_utils::TfLiteTypeFromPyArray(array) != tensor->type) {
PyErr_Format(PyExc_ValueError,
@@ -279,26 +267,41 @@
}
if (PyArray_NDIM(array) != tensor->dims->size) {
- PyErr_SetString(PyExc_ValueError, "Cannot set tensor: Dimension mismatch");
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch."
+ " Got %d"
+ " but expected %d for input %d.",
+ PyArray_NDIM(array), tensor->dims->size, i);
return nullptr;
}
for (int j = 0; j < PyArray_NDIM(array); j++) {
if (tensor->dims->data[j] != PyArray_SHAPE(array)[j]) {
- PyErr_SetString(PyExc_ValueError,
- "Cannot set tensor: Dimension mismatch");
+ PyErr_Format(PyExc_ValueError,
+ "Cannot set tensor: Dimension mismatch."
+ " Got %ld"
+ " but expected %d for dimension %d of input %d.",
+ PyArray_SHAPE(array)[j], tensor->dims->data[j], j, i);
return nullptr;
}
}
- size_t size = PyArray_NBYTES(array);
- if (size != tensor->bytes) {
- PyErr_Format(PyExc_ValueError,
- "numpy array had %zu bytes but expected %zu bytes.", size,
- tensor->bytes);
- return nullptr;
+ if (tensor->type != kTfLiteString) {
+ size_t size = PyArray_NBYTES(array);
+ if (size != tensor->bytes) {
+ PyErr_Format(PyExc_ValueError,
+ "numpy array had %zu bytes but expected %zu bytes.", size,
+ tensor->bytes);
+ return nullptr;
+ }
+ memcpy(tensor->data.raw, PyArray_DATA(array), size);
+ } else {
+ DynamicBuffer dynamic_buffer;
+ if (!python_utils::FillStringBufferWithPyArray(value, &dynamic_buffer)) {
+ return nullptr;
+ }
+ dynamic_buffer.WriteToTensor(tensor, nullptr);
}
- memcpy(tensor->data.raw, PyArray_DATA(array), size);
Py_RETURN_NONE;
}
@@ -345,19 +348,51 @@
std::vector<npy_intp> dims(tensor->dims->data,
tensor->dims->data + tensor->dims->size);
- // Make a buffer copy but we must tell Numpy It owns that data or else
- // it will leak.
- void* data = malloc(tensor->bytes);
- if (!data) {
- PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
- return nullptr;
+ if (tensor->type != kTfLiteString) {
+ // Make a buffer copy but we must tell Numpy It owns that data or else
+ // it will leak.
+ void* data = malloc(tensor->bytes);
+ if (!data) {
+ PyErr_SetString(PyExc_ValueError, "Malloc to copy tensor failed.");
+ return nullptr;
+ }
+ memcpy(data, tensor->data.raw, tensor->bytes);
+ PyObject* np_array =
+ PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
+ PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
+ NPY_ARRAY_OWNDATA);
+ return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
+ } else {
+ // Create a C-order array so the data is contiguous in memory.
+ const int32_t kCOrder = 0;
+ PyObject* py_object =
+ PyArray_EMPTY(dims.size(), dims.data(), NPY_OBJECT, kCOrder);
+
+ if (py_object == nullptr) {
+ PyErr_SetString(PyExc_MemoryError, "Failed to allocate PyArray.");
+ return nullptr;
+ }
+
+ PyArrayObject* py_array = reinterpret_cast<PyArrayObject*>(py_object);
+ PyObject** data = reinterpret_cast<PyObject**>(PyArray_DATA(py_array));
+ auto num_strings = GetStringCount(tensor->data.raw);
+ for (int j = 0; j < num_strings; ++j) {
+ auto ref = GetString(tensor->data.raw, j);
+
+ PyObject* bytes = PyBytes_FromStringAndSize(ref.str, ref.len);
+ if (bytes == nullptr) {
+ Py_DECREF(py_object);
+ PyErr_Format(PyExc_ValueError,
+ "Could not create PyBytes from string %d of input %d.", j,
+ i);
+ return nullptr;
+ }
+ // PyArray_EMPTY produces an array full of Py_None, which we must decref.
+ Py_DECREF(data[j]);
+ data[j] = bytes;
+ }
+ return py_object;
}
- memcpy(data, tensor->data.raw, tensor->bytes);
- PyObject* np_array =
- PyArray_SimpleNewFromData(dims.size(), dims.data(), type_num, data);
- PyArray_ENABLEFLAGS(reinterpret_cast<PyArrayObject*>(np_array),
- NPY_ARRAY_OWNDATA);
- return PyArray_Return(reinterpret_cast<PyArrayObject*>(np_array));
}
PyObject* InterpreterWrapper::tensor(PyObject* base_object, int i) {
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.cc b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
new file mode 100644
index 0000000..ff5403d
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.cc
@@ -0,0 +1,25 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define TFLITE_IMPORT_NUMPY // See numpy.h for explanation.
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
+
+namespace tflite {
+namespace python {
+
+void ImportNumpy() { import_array1(); }
+
+} // namespace python
+} // namespace tflite
diff --git a/tensorflow/lite/python/interpreter_wrapper/numpy.h b/tensorflow/lite/python/interpreter_wrapper/numpy.h
new file mode 100644
index 0000000..a3b013f
--- /dev/null
+++ b/tensorflow/lite/python/interpreter_wrapper/numpy.h
@@ -0,0 +1,62 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
+#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
+
+#ifdef PyArray_Type
+#error "Numpy cannot be included before numpy.h."
+#endif
+
+// Disallow Numpy 1.7 deprecated symbols.
+#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
+
+// To handle PyArray_* calles, numpy defines a static lookup table called
+// PyArray_API, or PY_ARRAY_UNIQUE_SYMBOL, if defined. This causes the
+// PyArray_* pointers to be different for different translation units, unless
+// we take care of selectivel defined NO_IMPORT_ARRAY.
+//
+// Virtually every usage will define NO_IMPORT_ARRAY, and will have access to
+// the lookup table via:
+// extern void **PyArray_API;
+// In numpy.cc we will define TFLITE_IMPORT_NUMPY, effectively disabling that
+// and instead using:
+// void **PyArray_API;
+// which is initialized when ImportNumpy() is called.
+//
+// If we don't define PY_ARRAY_UNIQUE_SYMBOL then PyArray_API is a static
+// variable, which causes strange crashes when the pointers are used across
+// translation unit boundaries.
+//
+// For mone info see https://sourceforge.net/p/numpy/mailman/message/5700519
+// See also tensorflow/python/lib/core/numpy.h for a similar approach.
+#define PY_ARRAY_UNIQUE_SYMBOL _tensorflow_numpy_api
+#ifndef TFLITE_IMPORT_NUMPY
+#define NO_IMPORT_ARRAY
+#endif
+
+#include <Python.h>
+
+#include "numpy/arrayobject.h"
+#include "numpy/ufuncobject.h"
+
+namespace tflite {
+namespace python {
+
+void ImportNumpy();
+
+} // namespace python
+} // namespace tflite
+
+#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_NUMPY_H_
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
index 2dc6043..a052ca3 100644
--- a/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
+++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.cc
@@ -15,9 +15,19 @@
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
+#include <memory>
+
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
+
namespace tflite {
namespace python_utils {
+struct PyObjectDereferencer {
+ void operator()(PyObject* py_object) const { Py_DECREF(py_object); }
+};
+
+using UniquePyObjectRef = std::unique_ptr<PyObject, PyObjectDereferencer>;
+
int TfLiteTypeToPyArrayType(TfLiteType tf_lite_type) {
switch (tf_lite_type) {
case kTfLiteFloat32:
@@ -33,7 +43,7 @@
case kTfLiteInt64:
return NPY_INT64;
case kTfLiteString:
- return NPY_OBJECT;
+ return NPY_STRING;
case kTfLiteBool:
return NPY_BOOL;
case kTfLiteComplex64:
@@ -73,5 +83,82 @@
return kTfLiteNoType;
}
+#if PY_VERSION_HEX >= 0x03030000
+bool FillStringBufferFromPyUnicode(PyObject* value,
+ DynamicBuffer* dynamic_buffer) {
+ Py_ssize_t len = -1;
+ char* buf = PyUnicode_AsUTF8AndSize(value, &len);
+ if (buf == NULL) {
+ PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8AndSize() failed.");
+ return false;
+ }
+ dynamic_buffer->AddString(buf, len);
+ return true;
+}
+#else
+bool FillStringBufferFromPyUnicode(PyObject* value,
+ DynamicBuffer* dynamic_buffer) {
+ UniquePyObjectRef utemp(PyUnicode_AsUTF8String(value));
+ if (!utemp) {
+ PyErr_SetString(PyExc_ValueError, "PyUnicode_AsUTF8String() failed.");
+ return false;
+ }
+ char* buf = nullptr;
+ Py_ssize_t len = -1;
+ if (PyBytes_AsStringAndSize(utemp.get(), &buf, &len) == -1) {
+ PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
+ return false;
+ }
+ dynamic_buffer->AddString(buf, len);
+ return true;
+}
+#endif
+
+bool FillStringBufferFromPyString(PyObject* value,
+ DynamicBuffer* dynamic_buffer) {
+ if (PyUnicode_Check(value)) {
+ return FillStringBufferFromPyUnicode(value, dynamic_buffer);
+ }
+
+ char* buf = nullptr;
+ Py_ssize_t len = -1;
+ if (PyBytes_AsStringAndSize(value, &buf, &len) == -1) {
+ PyErr_SetString(PyExc_ValueError, "PyBytes_AsStringAndSize() failed.");
+ return false;
+ }
+ dynamic_buffer->AddString(buf, len);
+ return true;
+}
+
+bool FillStringBufferWithPyArray(PyObject* value,
+ DynamicBuffer* dynamic_buffer) {
+ PyArrayObject* array = reinterpret_cast<PyArrayObject*>(value);
+ switch (PyArray_TYPE(array)) {
+ case NPY_OBJECT:
+ case NPY_STRING:
+ case NPY_UNICODE: {
+ UniquePyObjectRef iter(PyArray_IterNew(value));
+ while (PyArray_ITER_NOTDONE(iter.get())) {
+ UniquePyObjectRef item(PyArray_GETITEM(
+ array, reinterpret_cast<char*>(PyArray_ITER_DATA(iter.get()))));
+
+ if (!FillStringBufferFromPyString(item.get(), dynamic_buffer)) {
+ return false;
+ }
+
+ PyArray_ITER_NEXT(iter.get());
+ }
+ return true;
+ }
+ default:
+ break;
+ }
+
+ PyErr_Format(PyExc_ValueError,
+ "Cannot use numpy array of type %d for string tensor.",
+ PyArray_TYPE(array));
+ return false;
+}
+
} // namespace python_utils
} // namespace tflite
diff --git a/tensorflow/lite/python/interpreter_wrapper/python_utils.h b/tensorflow/lite/python/interpreter_wrapper/python_utils.h
index 30a4422..5ffd231 100644
--- a/tensorflow/lite/python/interpreter_wrapper/python_utils.h
+++ b/tensorflow/lite/python/interpreter_wrapper/python_utils.h
@@ -17,14 +17,8 @@
#define TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
#include "tensorflow/lite/context.h"
-
-// Disallow Numpy 1.7 deprecated symbols.
-#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
-
-#include <Python.h>
-
-#include "numpy/arrayobject.h"
-#include "numpy/ufuncobject.h"
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
+#include "tensorflow/lite/string_util.h"
namespace tflite {
namespace python_utils {
@@ -33,6 +27,9 @@
TfLiteType TfLiteTypeFromPyArray(PyArrayObject* array);
+bool FillStringBufferWithPyArray(PyObject* value,
+ DynamicBuffer* dynamic_buffer);
+
} // namespace python_utils
} // namespace tflite
#endif // TENSORFLOW_LITE_PYTHON_INTERPRETER_WRAPPER_PYTHON_UTILS_H_
diff --git a/tensorflow/lite/python/lite_flex_test.py b/tensorflow/lite/python/lite_flex_test.py
new file mode 100644
index 0000000..a5ae629
--- /dev/null
+++ b/tensorflow/lite/python/lite_flex_test.py
@@ -0,0 +1,58 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lite.py functionality related to select TF op usage."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.lite.python import lite
+from tensorflow.lite.python.interpreter import Interpreter
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+@test_util.run_v1_only('b/120545219')
+class FromSessionTest(test_util.TensorFlowTestCase):
+
+ def testFlexMode(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
+ converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensures the model contains TensorFlow ops.
+ # TODO(nupurgarg): Check values once there is a Python delegate interface.
+ interpreter = Interpreter(model_content=tflite_model)
+ with self.assertRaises(RuntimeError) as error:
+ interpreter.allocate_tensors()
+ self.assertIn(
+ 'Regular TensorFlow ops are not supported by this interpreter. Make '
+ 'sure you invoke the Flex delegate before inference.',
+ str(error.exception))
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/lite/python/lite_test.py b/tensorflow/lite/python/lite_test.py
index 83fd56b..ca6c5b8 100644
--- a/tensorflow/lite/python/lite_test.py
+++ b/tensorflow/lite/python/lite_test.py
@@ -131,13 +131,13 @@
input_details = interpreter.get_input_details()
self.assertEqual(1, len(input_details))
self.assertEqual('Placeholder', input_details[0]['name'])
- self.assertEqual(np.object_, input_details[0]['dtype'])
+ self.assertEqual(np.string_, input_details[0]['dtype'])
self.assertTrue(([4] == input_details[0]['shape']).all())
output_details = interpreter.get_output_details()
self.assertEqual(1, len(output_details))
self.assertEqual('Reshape', output_details[0]['name'])
- self.assertEqual(np.object_, output_details[0]['dtype'])
+ self.assertEqual(np.string_, output_details[0]['dtype'])
self.assertTrue(([2, 2] == output_details[0]['shape']).all())
# TODO(b/122659643): Test setting/getting string data via the python
# interpreter API after support has been added.
@@ -511,29 +511,6 @@
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
- def testFlexMode(self):
- in_tensor = array_ops.placeholder(
- shape=[1, 16, 16, 3], dtype=dtypes.float32)
- out_tensor = in_tensor + in_tensor
- sess = session.Session()
-
- # Convert model and ensure model is not None.
- converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
- [out_tensor])
- converter.target_ops = set([lite.OpsSet.SELECT_TF_OPS])
- tflite_model = converter.convert()
- self.assertTrue(tflite_model)
-
- # Ensures the model contains TensorFlow ops.
- # TODO(nupurgarg): Check values once there is a Python delegate interface.
- interpreter = Interpreter(model_content=tflite_model)
- with self.assertRaises(RuntimeError) as error:
- interpreter.allocate_tensors()
- self.assertIn(
- 'Regular TensorFlow ops are not supported by this interpreter. Make '
- 'sure you invoke the Flex delegate before inference.',
- str(error.exception))
-
def testFloatTocoConverter(self):
"""Tests deprecated test TocoConverter."""
in_tensor = array_ops.placeholder(
diff --git a/tensorflow/lite/python/op_hint.py b/tensorflow/lite/python/op_hint.py
index 8df37c1..f107aba 100644
--- a/tensorflow/lite/python/op_hint.py
+++ b/tensorflow/lite/python/op_hint.py
@@ -71,6 +71,7 @@
import collections as _collections
import copy as _copy
+import json as _json
import uuid as _uuid
import six as _six
@@ -132,6 +133,14 @@
# "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
# attribute to [2, 0, 1, -1].
TFLITE_INPUT_INDICES = "_tflite_input_indices"
+ # OpHint level.
+ FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
+ # Ophint internal mapping, this is for high level Ophint only.
+ # This basically contains three kinds of mapping:
+ # 1) How parental ophinted inputs map to the first child ophinted inputs;
+ # 2) How internal children nodes are connected;
+ # 3) How parental ophinted outputs map to the last child ophinted outputs.
+ CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"
# Types of aggregations
# stack: stacks all ophints with matching tags. i.e. for a static rnn.
@@ -149,10 +158,16 @@
"""Conceptually tracks indices of arguments of "OpHint functions".
The inputs and arguments of these functions both use an instance
- of the class so they can have independent numbering."""
+ of the class so they can have independent numbering.
+ """
- def __init__(self, function_name, unique_function_id, node_name_prefix,
- attr_name):
+ def __init__(self,
+ function_name,
+ unique_function_id,
+ node_name_prefix,
+ attr_name,
+ level=1,
+ children_inputs_mappings=None):
"""Initialize ophint argument.
Args:
@@ -161,6 +176,8 @@
node_name_prefix: How identities that are created are named.
attr_name: Name of attribute to use to store the index for this hint.
i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
+ level: Hierarchical level of the Ophint node, a number.
+ children_inputs_mappings: Inputs/Outputs mapping for children hints.
"""
# The global index is the argument index of the op. This is in contrast
@@ -176,6 +193,8 @@
self._tag_to_next_sort_index = {} # The current index for each tag
self._node_name_prefix = node_name_prefix
self._attr_name = attr_name
+ self._level = level
+ self._children_inputs_mappings = children_inputs_mappings
def _get_new_global_index(self, index_override):
"""Return the next unused argument index in order or use an override.
@@ -251,6 +270,7 @@
uuid = self._unique_function_id
name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
uuid, global_index, sort_index, name)
+
identity_op = _array_ops.identity(arg, name=name)
# pylint: disable=protected-access
@@ -264,6 +284,15 @@
s=_compat.as_bytes(self._unique_function_id)))
identity_op.op._set_attr(
self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
+ identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
+ _attr_value_pb2.AttrValue(i=self._level))
+ if self._children_inputs_mappings:
+ identity_op.op._set_attr(
+ OpHint.CHILDREN_INPUTS_MAPPINGS,
+ _attr_value_pb2.AttrValue(
+ s=_compat.as_bytes(_json.dumps(
+ self._children_inputs_mappings))))
+
if sort_index is not None:
identity_op.op._set_attr(
OpHint.FUNCTION_SORT_INDEX_ATTR,
@@ -275,23 +304,74 @@
# pylint: enable=protected-access
return identity_op
- def __init__(self, function_name, **kwargs):
+ def __init__(self,
+ function_name,
+ level=1,
+ children_inputs_mappings=None,
+ **kwargs):
"""Create a OpHint.
Args:
function_name: Name of the function (the custom op name in tflite)
+ level: OpHint level.
+ children_inputs_mappings: Children OpHint inputs/outputs mapping.
+ children_inputs_mappings should like below:
+ "parent_first_child_input":
+ [{"parent_input_index": num, "child_input_index": num}, ...]
+ "parent_last_child_output":
+ [{"parent_output_index": num, "child_output_index": num}, ...]
+ "internal_children_input_output":
+ [{"child_input_index": num, "child_output_index": num}, ...]
**kwargs: Keyword arguments of any constant attributes for the function.
"""
self._function_name = function_name
+ self._level = level
+ if self._level == 1:
+ assert children_inputs_mappings is None
+ else:
+ assert isinstance(children_inputs_mappings, dict)
+ self._children_inputs_mappings = children_inputs_mappings
+ if self._children_inputs_mappings is not None:
+ self._validate_children_inputs_mappings(self._children_inputs_mappings)
self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough?
self._attrs_to_store_later = kwargs
self._stored_attrs = False
self._inputs = OpHint.OpHintArgumentTracker(
self._function_name, self._unique_function_id, "InputHint",
- OpHint.FUNCTION_INPUT_INDEX_ATTR)
+ OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
self._outputs = OpHint.OpHintArgumentTracker(
self._function_name, self._unique_function_id, "OutputHint",
- OpHint.FUNCTION_OUTPUT_INDEX_ATTR)
+ OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
+ self._children_inputs_mappings)
+
+ def _validate_children_inputs_mappings(self, children_inputs_mappings):
+ """Validate children inputs mappings is in the right format.
+
+ Args:
+ children_inputs_mappings: the Children ophint inputs/outputs mapping.
+ """
+ assert isinstance(children_inputs_mappings, dict)
+ assert "parent_first_child_input" in children_inputs_mappings
+ assert "parent_last_child_output" in children_inputs_mappings
+ assert "internal_children_input_output" in children_inputs_mappings
+
+ # validate parent_first_child_input.
+
+ def assert_dictlist_has_keys(dictlist, keys):
+ for dikt in dictlist:
+ assert isinstance(dikt, dict)
+ for key in keys:
+ assert key in dikt
+
+ assert_dictlist_has_keys(
+ children_inputs_mappings["parent_first_child_input"],
+ ["parent_ophint_input_index", "first_child_ophint_input_index"])
+ assert_dictlist_has_keys(
+ children_inputs_mappings["parent_last_child_output"],
+ ["parent_output_index", "child_output_index"])
+ assert_dictlist_has_keys(
+ children_inputs_mappings["internal_children_input_output"],
+ ["child_input_index", "child_output_index"])
def _setattr(self, dest_op, name, value):
tensor_value = _ops.convert_to_tensor(value)
@@ -382,7 +462,7 @@
class _LiteOperand(object):
- """Abstract operand for a tflite hint function.
+ """Abstract operand for a tflite hint function._dynamic_rnn_loop.
This is a base class that handles representing arguments to an OpHint.
It also is able to serialize operands to the stubbed graph_def.
@@ -580,15 +660,18 @@
This is uses to accumulate found hints in the graphdef into a single
conceptual unit.
- Properties:
- self.inputs: inputs to the op (hash from index # to argument)
- self.outputs: outputs to the op (hash from index # to argument)
- self.function_name: the tflite custom op name to use
- self.uuid: a unique call id for this particular call (i.e.
+ Attributes:
+ inputs: inputs to the op (hash from index # to argument)
+ outputs: outputs to the op (hash from index # to argument)
+ function_name: the tflite custom op name to use
+ uuid: a unique call id for this particular call (i.e.
multiple function calls would have the same function_name but different
uuids.
- self.params: A param name to key value for op constant data. I.e. for
+ params: A param name to key value for op constant data. I.e. for
axis on a reduction, strides on a convolution, etc.
+ level: Level of the OpHint.
+ children_inputs_mappings: If the Ophint has children, children inputs
+ mappings indicate how their inputs & outputs are mapped.
"""
def __init__(self):
@@ -597,6 +680,8 @@
self.function_name = None
self.uuid = None
self.params = {}
+ self.level = -1
+ self.children_inputs_mappings = {}
def flattened_inputs_and_outputs(self):
"""Return a list of inputs and outputs in a flattened format.
@@ -622,22 +707,25 @@
inputs_str = "\tInputs\n" + format_args(self.inputs)
outputs_str = "\tOutputs\n" + format_args(self.outputs)
- return ("tflite function %s call %s\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s"
- % (self.function_name, self.uuid, inputs_str, outputs_str))
+ return (
+ "tflite function %s call %s level %d "
+ "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
+ (self.function_name, self.uuid, self.level, inputs_str, outputs_str))
-def _find_all_hints_in_graph_def(graphdef):
- """Look at the current default graph and return a list of LiteFuncCall objs.
+def _find_all_hints_in_nodes(nodes):
+ """Look at the all the input nodes and return a list of LiteFuncCall objs.
Args:
- graphdef: A TensorFlow graph_def to look for LiteFuncCalls.
+ nodes: A TensorFlow graph_def to look for LiteFuncCalls.
+
Returns:
a list of `LifeFuncCall` objects in the form
"""
func_calls = _collections.defaultdict(_LiteFuncCall)
- for node in graphdef.node:
+ for node in nodes:
attr = node.attr
# This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
@@ -649,6 +737,7 @@
call_def = func_calls[uuid]
call_def.uuid = uuid
call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
+ call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
# Get sorting and aggregation information
sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
@@ -658,6 +747,10 @@
if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
+ if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
+ call_def.children_inputs_mappings = _json.loads(
+ _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))
+
# Add the input or output
def put_operand(stuff, index, sort, operand, aggregation):
"""Add a given index into the function structure."""
@@ -683,6 +776,98 @@
return func_calls
+def _extract_topology_sequence_mapping(nodes):
+ return dict(
+ (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))
+
+
+def _find_children_hints_in_while_loop(function_def, nodes_mapping):
+ """Find children hints and all nodes inside the while loop.
+
+ Args:
+ function_def: Function def of the while loop.
+ nodes_mapping: While loop input_arg : real node name.
+
+ Returns:
+ Ordered children hints and all re-mapped nodes inside the while loop.
+ """
+ new_nodes = []
+
+ # Make nodes inside function def inputs point to the real nodes.
+ for node in function_def.node_def:
+ for i in range(len(node.input)):
+ if node.input[i] in nodes_mapping:
+ node.input[i] = nodes_mapping[node.input[i]]
+ new_nodes.append(_copy.deepcopy(node))
+ name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
+ children_hints = _find_all_hints_in_nodes(new_nodes)
+ children_hints_q = []
+ # Ordered by the outputs.
+ for hint in _six.itervalues(children_hints):
+ _, output_names = hint.flattened_inputs_and_outputs()
+ seq = name_to_seq_num[output_names[0]]
+ for output_name in output_names:
+ seq = min(seq, name_to_seq_num[output_name])
+ children_hints_q.append((seq, hint))
+ children_hints_q.sort(key=lambda tup: tup[0])
+ ordered_children_hints = [x[1] for x in children_hints_q]
+ return ordered_children_hints, new_nodes
+
+
+def _find_children_hints(call, graph_def):
+ """Find all children hints.
+
+ For a given OpHint, we find all children hints inside it, we also copy all the
+ nodes inside function defs (if applicable) to the original graph_def, they are
+ returned in a list as well.
+
+ Args:
+ call: Parent OpHint that contains children ophints.
+ graph_def: Original graph def.
+
+ Returns:
+ Ordered children hints inside the parent ophint; new graph def that contains
+ nodes inside function defs (if applicable); nodes inside function defs.
+ """
+ name_to_input_name, _, _ = _extract_graph_summary(graph_def)
+ input_names, output_names = call.flattened_inputs_and_outputs()
+
+ reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
+ reachable_by_output = _bfs_for_reachable_nodes(output_names,
+ name_to_input_name)
+ output_nodes_set = set(output_names)
+ children_hints = []
+ out = _graph_pb2.GraphDef()
+ out.library.CopyFrom(graph_def.library)
+ out.versions.CopyFrom(graph_def.versions)
+ function_def_nodes = set()
+ for node in graph_def.node:
+ out.node.extend([_copy.deepcopy(node)])
+ n = _tensor_name_base(node.name)
+ if n in reachable_by_output:
+ if n not in reachable_by_input and n not in output_nodes_set:
+ # special handle for while loop function def.
+ if node.op == "While":
+ body_name = node.attr["body"].func.name
+ inputs_outside_loop = node.input
+ for function_def in graph_def.library.function:
+ if function_def.signature.name == body_name:
+ function_inputs = function_def.signature.input_arg
+ assert len(inputs_outside_loop) == len(function_inputs)
+ nodes_mapping = {}
+ for i in range(len(function_inputs)):
+ nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i]
+ # TODO(b/123050804): Consider use grappler.
+ (children_hints_in_loop,
+ new_nodes) = _find_children_hints_in_while_loop(
+ function_def, nodes_mapping)
+ function_def_nodes.update([x.name for x in new_nodes])
+ children_hints.extend(children_hints_in_loop)
+ out.node.extend(new_nodes)
+
+ return children_hints, out, function_def_nodes
+
+
def _tensor_name_base(full_tensor_name):
"""Removes the device assignment code from a tensor.
@@ -735,12 +920,20 @@
# TODO(aselle): This should be converted to grappler in the future.
-def _convert_single_op_hint_to_stub(call, graph_def):
+def _convert_single_op_hint_to_stub(call,
+ graph_def,
+ function_def_nodes=None,
+ is_last_run=True):
"""Given a graph_def, converts `call` into a stub and returns a new graph_def.
Args:
call: A single function call to be converted.
graph_def: A graph_def to use as input (that hass call obviously).
+ function_def_nodes: Nodes inside the function def those are not connected to
+ the graph.
+ is_last_run: Whether it is the last run for a given pass (for OpHint has
+ children).
+
Returns:
A new transformed graph-def that has call as a stub (single op).
@@ -748,6 +941,8 @@
the tensorflow runtime, so all future manipulations are done in graph_def
level.
"""
+ if function_def_nodes is None:
+ function_def_nodes = set()
name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
graph_def)
input_names, output_names = call.flattened_inputs_and_outputs()
@@ -755,7 +950,6 @@
reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
reachable_by_output = _bfs_for_reachable_nodes(output_names,
name_to_input_name)
- input_nodes_set = set(input_names)
output_nodes_set = set(output_names)
nodes_after_fuse = []
nodes_deleted_by_fuse = set()
@@ -766,19 +960,16 @@
n = _tensor_name_base(node.name)
if n in reachable_by_output:
if n not in reachable_by_input and n not in output_nodes_set:
- # n is an internal node. Check to make sure it is really internal.
- # TODO(aselle): this could be done more efficiently by flooding
- # the graph first.
- _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
- name_to_input_name)
nodes_deleted_by_fuse.add(n)
- elif n not in reachable_by_input:
+ elif n not in reachable_by_input and n not in function_def_nodes:
# n is a node that after all the fusings, so keep it.
nodes_after_fuse.append(n)
else:
- # n is a node that is randomly in the graph but not connected to
- # the chain of dependencies.
- pass
+ # In the last run, n is a node that is randomly in the graph but not
+ # connected to the chain of dependencies, we will delete n, otherwise
+ # we keep them.
+ if not is_last_run:
+ nodes_after_fuse.append(n)
# Make a new graphdef with all the pre-input and input nodes
out = _graph_pb2.GraphDef()
@@ -800,7 +991,8 @@
# non-fused things.
for input_index in sorted_input_indices:
inputs = call.inputs[input_index]
- new_node.input.append(inputs.aggregate_and_return_name_for_input(out))
+ input_name = inputs.aggregate_and_return_name_for_input(out)
+ new_node.input.append(input_name)
new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
# Ceate the function
@@ -936,6 +1128,18 @@
return curr
+def _get_correct_mapping(original_index, nodes):
+ # Special handle for the index is -1 case.
+ # If it is -1, return the last index.
+ if original_index == -1:
+ node_indices = nodes.keys()
+ node_indices.sort()
+ return node_indices[-1]
+ else:
+ return original_index
+ return original_index
+
+
@_tf_export("lite.convert_op_hints_to_stubs")
def _convert_op_hints_to_stubs_helper(
graph_def, write_callback=lambda sess, graph_def: None):
@@ -948,14 +1152,67 @@
Returns:
A new stubbed graph_def.
"""
+ hints = _find_all_hints_in_nodes(graph_def.node)
- hints = _find_all_hints_in_graph_def(graph_def)
+ hints_q = []
+ for hint in _six.itervalues(hints):
+ hints_q.append((hint.level, hint.uuid))
+
+ hints_q.sort(key=lambda tup: tup[0])
+ for i in range(len(hints_q) - 1, -1, -1):
+ level, hint_uuid = hints_q[i]
+
curr_graph_def = graph_def
del graph_def # prevent using graph_def again (common source of error)
- for hint in _six.itervalues(hints):
- curr_graph_def = _convert_single_op_hint_to_stub(
- hint, curr_graph_def)
- write_callback(curr_graph_def, "initial")
+ for i in range(len(hints_q) - 1, -1, -1):
+ level, hint_uuid = hints_q[i]
+ if level >= 2:
+ children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
+ hints[hint_uuid], curr_graph_def)
+ # pylint: disable=superfluous-parens
+ assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test
+ # pylint: enable=superfluous-parens
+
+ # Re-wire the children hints inputs/outputs, so latter child's inputs
+ # connect to previous child node's outputs.
+ children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
+ for j in range(len(children_hints)):
+ child_hint = children_hints[j]
+ if j == 0:
+ for mapping in children_inputs_mappings["parent_first_child_input"]:
+ parent_input_index = _get_correct_mapping(
+ mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
+ child_input_index = _get_correct_mapping(
+ mapping["first_child_ophint_input_index"], child_hint.inputs)
+ child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
+ parent_input_index]
+ else:
+ for mapping in children_inputs_mappings[
+ "internal_children_input_output"]:
+ input_index = _get_correct_mapping(mapping["child_input_index"],
+ child_hint.inputs)
+ output_index = _get_correct_mapping(mapping["child_output_index"],
+ children_hints[j - 1].outputs)
+ child_hint.inputs[input_index] = children_hints[
+ j - 1].outputs[output_index]
+ if j == len(children_hints) - 1:
+ for mapping in children_inputs_mappings["parent_last_child_output"]:
+ parent_output_index = _get_correct_mapping(
+ mapping["parent_output_index"], hints[hint_uuid].outputs)
+ child_output_index = _get_correct_mapping(
+ mapping["child_output_index"], child_hint.outputs)
+ child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
+ parent_output_index]
+
+ for j in range(len(children_hints)):
+ child_hint = children_hints[j]
+ curr_graph_def = _convert_single_op_hint_to_stub(
+ child_hint, curr_graph_def, function_def_nodes,
+ j == len(children_hints) - 1)
+ else:
+ curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
+ curr_graph_def)
+ write_callback(curr_graph_def, "initial")
# The stubbing process can create stacks/unstacks in the case of LSTMs
# remove them.
curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
@@ -982,9 +1239,9 @@
raise ValueError("Provide only one of session and graph_def.")
hinted_outputs_nodes = []
if session is not None:
- hints = _find_all_hints_in_graph_def(session.graph_def)
+ hints = _find_all_hints_in_nodes(session.graph_def.node)
elif graph_def is not None:
- hints = _find_all_hints_in_graph_def(graph_def)
+ hints = _find_all_hints_in_nodes(graph_def.node)
for hint in _six.itervalues(hints):
_, ouput_nodes = hint.flattened_inputs_and_outputs()
hinted_outputs_nodes.extend(ouput_nodes)
diff --git a/tensorflow/lite/python/optimize/BUILD b/tensorflow/lite/python/optimize/BUILD
index 74b573b..8694ebf 100644
--- a/tensorflow/lite/python/optimize/BUILD
+++ b/tensorflow/lite/python/optimize/BUILD
@@ -13,6 +13,7 @@
deps = [
"//tensorflow/lite:framework",
"//tensorflow/lite/kernels:builtin_ops",
+ "//tensorflow/lite/python/interpreter_wrapper:numpy",
"//tensorflow/lite/python/interpreter_wrapper:python_error_reporter",
"//tensorflow/lite/python/interpreter_wrapper:python_utils",
"//tensorflow/lite/tools/optimize:calibration_reader",
diff --git a/tensorflow/lite/python/optimize/calibration_wrapper.cc b/tensorflow/lite/python/optimize/calibration_wrapper.cc
index d6fe1fb..21f96f8 100644
--- a/tensorflow/lite/python/optimize/calibration_wrapper.cc
+++ b/tensorflow/lite/python/optimize/calibration_wrapper.cc
@@ -22,20 +22,13 @@
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
+#include "tensorflow/lite/python/interpreter_wrapper/numpy.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_error_reporter.h"
#include "tensorflow/lite/python/interpreter_wrapper/python_utils.h"
#include "tensorflow/lite/tools/optimize/calibration_reader.h"
#include "tensorflow/lite/tools/optimize/calibrator.h"
#include "tensorflow/lite/tools/optimize/quantize_model.h"
-// Disallow Numpy 1.7 deprecated symbols.
-#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
-
-#include <Python.h>
-
-#include "numpy/arrayobject.h"
-#include "numpy/ufuncobject.h"
-
#if PY_MAJOR_VERSION >= 3
#define PY_TO_CPPSTRING PyBytes_AsStringAndSize
#define CPP_TO_PYSTRING PyBytes_FromStringAndSize
diff --git a/tensorflow/lite/python/testdata/BUILD b/tensorflow/lite/python/testdata/BUILD
index 4689c31..2fa08e5 100644
--- a/tensorflow/lite/python/testdata/BUILD
+++ b/tensorflow/lite/python/testdata/BUILD
@@ -32,9 +32,20 @@
],
)
+tf_to_tflite(
+ name = "gather_string",
+ src = "gather.pbtxt",
+ out = "gather_string.tflite",
+ options = [
+ "--input_arrays=input,indices",
+ "--output_arrays=output",
+ ],
+)
+
filegroup(
name = "interpreter_test_data",
srcs = [
+ ":gather_string",
":permute_float",
":permute_uint8",
],
diff --git a/tensorflow/lite/python/testdata/gather.pbtxt b/tensorflow/lite/python/testdata/gather.pbtxt
new file mode 100644
index 0000000..0b1193c
--- /dev/null
+++ b/tensorflow/lite/python/testdata/gather.pbtxt
@@ -0,0 +1,93 @@
+node {
+ name: "input"
+ op: "Placeholder"
+ device: "/device:CPU:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_STRING
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 10
+ }
+ }
+ }
+ }
+}
+node {
+ name: "indices"
+ op: "Placeholder"
+ device: "/device:CPU:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: 3
+ }
+ }
+ }
+ }
+}
+node {
+ name: "axis"
+ op: "Const"
+ device: "/device:CPU:0"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 0
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "GatherV2"
+ input: "input"
+ input: "indices"
+ input: "axis"
+ device: "/device:CPU:0"
+ attr {
+ key: "Taxis"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tindices"
+ value {
+ type: DT_INT64
+ }
+ }
+ attr {
+ key: "Tparams"
+ value {
+ type: DT_STRING
+ }
+ }
+}
+versions {
+ producer: 27
+}
diff --git a/tensorflow/lite/schema/schema.fbs b/tensorflow/lite/schema/schema.fbs
index fbcb18f..b9c6e98 100644
--- a/tensorflow/lite/schema/schema.fbs
+++ b/tensorflow/lite/schema/schema.fbs
@@ -219,6 +219,8 @@
UNIQUE = 103,
CEIL = 104,
REVERSE_V2 = 105,
+ ADD_N = 106,
+ GATHER_ND = 107,
}
// Options for the builtin operators.
@@ -304,6 +306,8 @@
SplitVOptions,
UniqueOptions,
ReverseV2Options,
+ AddNOptions,
+ GatherNdOptions,
}
enum Padding : byte { SAME, VALID }
@@ -724,6 +728,12 @@
table ReverseV2Options {
}
+table AddNOptions {
+}
+
+table GatherNdOptions {
+}
+
// An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
// builtin, or a string if the operator is custom.
table OperatorCode {
diff --git a/tensorflow/lite/schema/schema_generated.h b/tensorflow/lite/schema/schema_generated.h
index 6ad7df0..3117789 100755
--- a/tensorflow/lite/schema/schema_generated.h
+++ b/tensorflow/lite/schema/schema_generated.h
@@ -274,6 +274,12 @@
struct ReverseV2Options;
struct ReverseV2OptionsT;
+struct AddNOptions;
+struct AddNOptionsT;
+
+struct GatherNdOptions;
+struct GatherNdOptionsT;
+
struct OperatorCode;
struct OperatorCodeT;
@@ -529,11 +535,13 @@
BuiltinOperator_UNIQUE = 103,
BuiltinOperator_CEIL = 104,
BuiltinOperator_REVERSE_V2 = 105,
+ BuiltinOperator_ADD_N = 106,
+ BuiltinOperator_GATHER_ND = 107,
BuiltinOperator_MIN = BuiltinOperator_ADD,
- BuiltinOperator_MAX = BuiltinOperator_REVERSE_V2
+ BuiltinOperator_MAX = BuiltinOperator_GATHER_ND
};
-inline const BuiltinOperator (&EnumValuesBuiltinOperator())[105] {
+inline const BuiltinOperator (&EnumValuesBuiltinOperator())[107] {
static const BuiltinOperator values[] = {
BuiltinOperator_ADD,
BuiltinOperator_AVERAGE_POOL_2D,
@@ -639,7 +647,9 @@
BuiltinOperator_SPLIT_V,
BuiltinOperator_UNIQUE,
BuiltinOperator_CEIL,
- BuiltinOperator_REVERSE_V2
+ BuiltinOperator_REVERSE_V2,
+ BuiltinOperator_ADD_N,
+ BuiltinOperator_GATHER_ND
};
return values;
}
@@ -752,6 +762,8 @@
"UNIQUE",
"CEIL",
"REVERSE_V2",
+ "ADD_N",
+ "GATHER_ND",
nullptr
};
return names;
@@ -845,11 +857,13 @@
BuiltinOptions_SplitVOptions = 79,
BuiltinOptions_UniqueOptions = 80,
BuiltinOptions_ReverseV2Options = 81,
+ BuiltinOptions_AddNOptions = 82,
+ BuiltinOptions_GatherNdOptions = 83,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_ReverseV2Options
+ BuiltinOptions_MAX = BuiltinOptions_GatherNdOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[82] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[84] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -932,7 +946,9 @@
BuiltinOptions_AbsOptions,
BuiltinOptions_SplitVOptions,
BuiltinOptions_UniqueOptions,
- BuiltinOptions_ReverseV2Options
+ BuiltinOptions_ReverseV2Options,
+ BuiltinOptions_AddNOptions,
+ BuiltinOptions_GatherNdOptions
};
return values;
}
@@ -1021,6 +1037,8 @@
"SplitVOptions",
"UniqueOptions",
"ReverseV2Options",
+ "AddNOptions",
+ "GatherNdOptions",
nullptr
};
return names;
@@ -1359,6 +1377,14 @@
static const BuiltinOptions enum_value = BuiltinOptions_ReverseV2Options;
};
+template<> struct BuiltinOptionsTraits<AddNOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_AddNOptions;
+};
+
+template<> struct BuiltinOptionsTraits<GatherNdOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_GatherNdOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -2038,6 +2064,22 @@
return type == BuiltinOptions_ReverseV2Options ?
reinterpret_cast<const ReverseV2OptionsT *>(value) : nullptr;
}
+ AddNOptionsT *AsAddNOptions() {
+ return type == BuiltinOptions_AddNOptions ?
+ reinterpret_cast<AddNOptionsT *>(value) : nullptr;
+ }
+ const AddNOptionsT *AsAddNOptions() const {
+ return type == BuiltinOptions_AddNOptions ?
+ reinterpret_cast<const AddNOptionsT *>(value) : nullptr;
+ }
+ GatherNdOptionsT *AsGatherNdOptions() {
+ return type == BuiltinOptions_GatherNdOptions ?
+ reinterpret_cast<GatherNdOptionsT *>(value) : nullptr;
+ }
+ const GatherNdOptionsT *AsGatherNdOptions() const {
+ return type == BuiltinOptions_GatherNdOptions ?
+ reinterpret_cast<const GatherNdOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -7174,6 +7216,86 @@
flatbuffers::Offset<ReverseV2Options> CreateReverseV2Options(flatbuffers::FlatBufferBuilder &_fbb, const ReverseV2OptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct AddNOptionsT : public flatbuffers::NativeTable {
+ typedef AddNOptions TableType;
+ AddNOptionsT() {
+ }
+};
+
+struct AddNOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef AddNOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ AddNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(AddNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<AddNOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct AddNOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit AddNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ AddNOptionsBuilder &operator=(const AddNOptionsBuilder &);
+ flatbuffers::Offset<AddNOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<AddNOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<AddNOptions> CreateAddNOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ AddNOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<AddNOptions> CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
+struct GatherNdOptionsT : public flatbuffers::NativeTable {
+ typedef GatherNdOptions TableType;
+ GatherNdOptionsT() {
+ }
+};
+
+struct GatherNdOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef GatherNdOptionsT NativeTableType;
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ verifier.EndTable();
+ }
+ GatherNdOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<GatherNdOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct GatherNdOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ explicit GatherNdOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ GatherNdOptionsBuilder &operator=(const GatherNdOptionsBuilder &);
+ flatbuffers::Offset<GatherNdOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<GatherNdOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(
+ flatbuffers::FlatBufferBuilder &_fbb) {
+ GatherNdOptionsBuilder builder_(_fbb);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct OperatorCodeT : public flatbuffers::NativeTable {
typedef OperatorCode TableType;
BuiltinOperator builtin_code;
@@ -7550,6 +7672,12 @@
const ReverseV2Options *builtin_options_as_ReverseV2Options() const {
return builtin_options_type() == BuiltinOptions_ReverseV2Options ? static_cast<const ReverseV2Options *>(builtin_options()) : nullptr;
}
+ const AddNOptions *builtin_options_as_AddNOptions() const {
+ return builtin_options_type() == BuiltinOptions_AddNOptions ? static_cast<const AddNOptions *>(builtin_options()) : nullptr;
+ }
+ const GatherNdOptions *builtin_options_as_GatherNdOptions() const {
+ return builtin_options_type() == BuiltinOptions_GatherNdOptions ? static_cast<const GatherNdOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -7905,6 +8033,14 @@
return builtin_options_as_ReverseV2Options();
}
+template<> inline const AddNOptions *Operator::builtin_options_as<AddNOptions>() const {
+ return builtin_options_as_AddNOptions();
+}
+
+template<> inline const GatherNdOptions *Operator::builtin_options_as<GatherNdOptions>() const {
+ return builtin_options_as_GatherNdOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -10575,6 +10711,52 @@
_fbb);
}
+inline AddNOptionsT *AddNOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new AddNOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void AddNOptions::UnPackTo(AddNOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<AddNOptions> AddNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateAddNOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<AddNOptions> CreateAddNOptions(flatbuffers::FlatBufferBuilder &_fbb, const AddNOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const AddNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateAddNOptions(
+ _fbb);
+}
+
+inline GatherNdOptionsT *GatherNdOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new GatherNdOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void GatherNdOptions::UnPackTo(GatherNdOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+}
+
+inline flatbuffers::Offset<GatherNdOptions> GatherNdOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateGatherNdOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<GatherNdOptions> CreateGatherNdOptions(flatbuffers::FlatBufferBuilder &_fbb, const GatherNdOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const GatherNdOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ return tflite::CreateGatherNdOptions(
+ _fbb);
+}
+
inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new OperatorCodeT();
UnPackTo(_o, _resolver);
@@ -11157,6 +11339,14 @@
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_AddNOptions: {
+ auto ptr = reinterpret_cast<const AddNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_GatherNdOptions: {
+ auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -11499,6 +11689,14 @@
auto ptr = reinterpret_cast<const ReverseV2Options *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_AddNOptions: {
+ auto ptr = reinterpret_cast<const AddNOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_GatherNdOptions: {
+ auto ptr = reinterpret_cast<const GatherNdOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -11829,6 +12027,14 @@
auto ptr = reinterpret_cast<const ReverseV2OptionsT *>(value);
return CreateReverseV2Options(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_AddNOptions: {
+ auto ptr = reinterpret_cast<const AddNOptionsT *>(value);
+ return CreateAddNOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_GatherNdOptions: {
+ auto ptr = reinterpret_cast<const GatherNdOptionsT *>(value);
+ return CreateGatherNdOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -12159,6 +12365,14 @@
value = new ReverseV2OptionsT(*reinterpret_cast<ReverseV2OptionsT *>(u.value));
break;
}
+ case BuiltinOptions_AddNOptions: {
+ value = new AddNOptionsT(*reinterpret_cast<AddNOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_GatherNdOptions: {
+ value = new GatherNdOptionsT(*reinterpret_cast<GatherNdOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -12571,6 +12785,16 @@
delete ptr;
break;
}
+ case BuiltinOptions_AddNOptions: {
+ auto ptr = reinterpret_cast<AddNOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_GatherNdOptions: {
+ auto ptr = reinterpret_cast<GatherNdOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;
diff --git a/tensorflow/lite/stderr_reporter.cc b/tensorflow/lite/stderr_reporter.cc
index 09eb1d2..366a181 100644
--- a/tensorflow/lite/stderr_reporter.cc
+++ b/tensorflow/lite/stderr_reporter.cc
@@ -13,28 +13,14 @@
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/stderr_reporter.h"
-#include <cstdarg>
-#include <cstdio>
-#ifdef __ANDROID__
-#include <android/log.h>
-#endif
+#include "tensorflow/lite/minimal_logging.h"
namespace tflite {
int StderrReporter::Report(const char* format, va_list args) {
-#ifdef __ANDROID__
- // On Android stderr is not captured for applications, only for code run from
- // the shell. Rather than assume all users will set up a custom error
- // reporter, let's output to logcat here
- va_list args_for_log;
- va_copy(args_for_log, args);
- __android_log_vprint(ANDROID_LOG_ERROR, "tflite", format, args_for_log);
- va_end(args_for_log);
-#endif
- const int result = vfprintf(stderr, format, args);
- fputc('\n', stderr);
- return result;
+ logging_internal::MinimalLogger::VLog(TFLITE_LOG_ERROR, format, args);
+ return 0;
}
ErrorReporter* DefaultErrorReporter() {
diff --git a/tensorflow/lite/string_util.h b/tensorflow/lite/string_util.h
index f076db7..cb268ee 100644
--- a/tensorflow/lite/string_util.h
+++ b/tensorflow/lite/string_util.h
@@ -35,7 +35,7 @@
// buf.AddString("AB", 2);
// # Write content of DynamicBuffer to tensor in format of string tensor
// # described above.
-// buf.WriteToTensor(tensor)
+// buf.WriteToTensor(tensor, nullptr)
#ifndef TENSORFLOW_LITE_STRING_UTIL_H_
#define TENSORFLOW_LITE_STRING_UTIL_H_
@@ -83,10 +83,6 @@
// Fill content into a string tensor. Set shape to {num_strings}.
void WriteToTensorAsVector(TfLiteTensor* tensor);
- // Deprecated. Use WriteToTensorAsVector() or pass in the new shpe.
- // TODO(b/120230709): remove when people migrate away.
- void WriteToTensor(TfLiteTensor* tensor) { WriteToTensorAsVector(tensor); }
-
private:
// Data buffer to store contents of strings, not including headers.
std::vector<char> data_;
diff --git a/tensorflow/lite/testing/generate_examples.py b/tensorflow/lite/testing/generate_examples.py
index ee68607..7ebc595 100644
--- a/tensorflow/lite/testing/generate_examples.py
+++ b/tensorflow/lite/testing/generate_examples.py
@@ -1184,6 +1184,51 @@
make_binary_op_tests(zip_path, tf.add)
+def make_add_n_tests(zip_path):
+ """Make a set of tests for AddN op."""
+
+ test_parameters = [
+ {
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[2, 5, 3, 1]],
+ "num_inputs": [2, 3, 4, 5],
+ },
+ {
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[5]],
+ "num_inputs": [2, 3, 4, 5],
+ },
+ {
+ "dtype": [tf.float32, tf.int32],
+ "input_shape": [[]],
+ "num_inputs": [2, 3, 4, 5],
+ },
+ ]
+
+ def build_graph(parameters):
+ """Builds the graph given the current parameters."""
+ input_tensors = []
+ for i in range(parameters["num_inputs"]):
+ input_tensors.append(
+ tf.placeholder(
+ dtype=parameters["dtype"],
+ name="input_{}".format(i),
+ shape=parameters["input_shape"]))
+ out = tf.add_n(input_tensors)
+ return input_tensors, [out]
+
+ def build_inputs(parameters, sess, inputs, outputs):
+ """Builds operand inputs for op."""
+ input_data = []
+ for i in range(parameters["num_inputs"]):
+ input_data.append(
+ create_tensor_data(parameters["dtype"], parameters["input_shape"]))
+ return input_data, sess.run(
+ outputs, feed_dict={i: d for i, d in zip(inputs, input_data)})
+
+ make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
def make_div_tests(zip_path):
make_binary_op_tests(zip_path, tf.div)
diff --git a/tensorflow/lite/tflite_exported_symbols.lds b/tensorflow/lite/tflite_exported_symbols.lds
new file mode 100644
index 0000000..b145204
--- /dev/null
+++ b/tensorflow/lite/tflite_exported_symbols.lds
@@ -0,0 +1,3 @@
+*TfLite*
+*tflite*
+*TFL_*
diff --git a/tensorflow/lite/tflite_version_script.lds b/tensorflow/lite/tflite_version_script.lds
new file mode 100644
index 0000000..1df7070
--- /dev/null
+++ b/tensorflow/lite/tflite_version_script.lds
@@ -0,0 +1,8 @@
+VERS_1.0 {
+ global:
+ *TfLite*;
+ *tflite*;
+ *TFL_*;
+ local:
+ *;
+};
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
index a0260e2..e4eb769 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.cc
@@ -128,7 +128,8 @@
}
bool GraphTransformationsPass(int increment, Model* model,
- const GraphTransformationsSet& transformations) {
+ const GraphTransformationsSet& transformations,
+ tensorflow::Status* status) {
CHECK(increment == 1 || increment == -1);
bool changed = false;
if (model->operators.empty()) {
@@ -142,7 +143,10 @@
for (const auto& transformation : transformations) {
CHECK(!changed_now);
CHECK(transformation->Messages().empty());
- CHECK(transformation->Run(model, op_index, &changed_now).ok());
+ *status = transformation->Run(model, op_index, &changed_now);
+ if (!status->ok()) {
+ return false;
+ }
const char* made_a_change_msg =
changed_now ? "made a change" : "did NOT make a change";
const int log_level =
@@ -186,18 +190,21 @@
} // namespace
-void RunGraphTransformations(Model* model, const string& msg,
- const GraphTransformationsSet& transformations) {
+tensorflow::Status RunGraphTransformationsWithStatus(
+ Model* model, const string& msg,
+ const GraphTransformationsSet& transformations) {
PrintModelStats(toco::port::StringF("Before %s", msg), *model);
int pass_index = 0;
+ tensorflow::Status status;
while (GraphTransformationsPass((pass_index % 2) ? -1 : 1, model,
- transformations)) {
+ transformations, &status)) {
pass_index++;
const auto& label =
toco::port::StringF("After %s pass %d", msg, pass_index);
PrintModelStats(label, *model);
CheckInvariants(*model);
}
+ return status;
}
} // namespace toco
diff --git a/tensorflow/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
index 4008bbd..491a3e7 100644
--- a/tensorflow/lite/toco/graph_transformations/graph_transformations.h
+++ b/tensorflow/lite/toco/graph_transformations/graph_transformations.h
@@ -102,8 +102,16 @@
// construct GraphTransformation objects by using 'new', pass us
// the resulting raw pointers, and this RunGraphTransformations
// takes care of delete'ing these pointers.
-void RunGraphTransformations(Model* model, const string& message,
- const GraphTransformationsSet& transformations);
+tensorflow::Status RunGraphTransformationsWithStatus(
+ Model* model, const string& msg,
+ const GraphTransformationsSet& transformations);
+
+inline void RunGraphTransformations(
+ Model* model, const string& msg,
+ const GraphTransformationsSet& transformations) {
+ auto s = RunGraphTransformationsWithStatus(model, msg, transformations);
+ CHECK(s.ok()) << s.error_message();
+}
#define DECLARE_GRAPH_TRANSFORMATION(GTName) \
class GTName : public GraphTransformation { \
diff --git a/tensorflow/lite/toco/graph_transformations/quantize.cc b/tensorflow/lite/toco/graph_transformations/quantize.cc
index ee65f92..5a5c9bb 100644
--- a/tensorflow/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/lite/toco/graph_transformations/quantize.cc
@@ -489,12 +489,12 @@
}
}
if (!SupportsQuantization(op)) {
- LOG(FATAL) << "Unimplemented: this graph contains an operator of type "
- << HelpfulOperatorTypeName(op)
- << " for which the quantized form is not yet implemented. "
- "Sorry, and patches welcome (that's a relatively fun patch "
- "to write, mostly providing the actual quantized arithmetic "
- "code for this op).";
+ return tensorflow::errors::InvalidArgument(
+ "Unimplemented: this graph contains an operator of type ",
+ HelpfulOperatorTypeName(op),
+ " for which the quantized form is not yet implemented. Sorry, and "
+ "patches welcome (that's a relatively fun patch to write, mostly "
+ "providing the actual quantized arithmetic code for this op).");
}
for (const auto& input : op.inputs) {
diff --git a/tensorflow/lite/toco/import_tensorflow.cc b/tensorflow/lite/toco/import_tensorflow.cc
index b1b0494..813e439 100644
--- a/tensorflow/lite/toco/import_tensorflow.cc
+++ b/tensorflow/lite/toco/import_tensorflow.cc
@@ -2375,7 +2375,7 @@
return std::unordered_map<std::string, ConverterType>({
{"Abs", ConvertSimpleOperator<AbsOperator, kAnyNumInputs, 1>},
{"Add", ConvertSimpleOperator<AddOperator, 2, 1>},
- {"AddN", ConvertSimpleOperatorFlexOk<AddNOperator, kAnyNumInputs, 1>},
+ {"AddN", ConvertSimpleOperator<AddNOperator, kAnyNumInputs, 1>},
{"All", ConvertSimpleOperator<TensorFlowAllOperator, kAnyNumInputs, 1>},
{"Any", ConvertReduceOperator<TensorFlowAnyOperator>},
{"ArgMax", ConvertArgMaxOperator},
diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc
index f61488e..58c4a85 100644
--- a/tensorflow/lite/toco/tflite/operator.cc
+++ b/tensorflow/lite/toco/tflite/operator.cc
@@ -108,6 +108,12 @@
const Array& input_array = op_signature.model->GetArray(input_name);
const Array& filter_array = op_signature.model->GetArray(filter_name);
const Array& output_array = op_signature.model->GetArray(output_name);
+ // If the op has signed int8 inputs and outputs, its version 3.
+ if (input_array.data_type == ArrayDataType::kInt8 &&
+ filter_array.data_type == ArrayDataType::kInt8 &&
+ output_array.data_type == ArrayDataType::kInt8) {
+ return 3;
+ }
// If the op is a signed int8 hybrid operation, we need to return
// version 2.
if (input_array.data_type == ArrayDataType::kFloat &&
@@ -153,6 +159,18 @@
int GetVersion(const OperatorSignature& op_signature) const override {
const auto& conv_op =
static_cast<const DepthwiseConvOperator&>(*op_signature.op);
+ const string& input_name = op_signature.op->inputs[0];
+ const string& filter_name = op_signature.op->inputs[1];
+ const string& output_name = op_signature.op->outputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ const Array& filter_array = op_signature.model->GetArray(filter_name);
+ const Array& output_array = op_signature.model->GetArray(output_name);
+ // If the op has signed int8 inputs and outputs, its version 3.
+ if (input_array.data_type == ArrayDataType::kInt8 &&
+ filter_array.data_type == ArrayDataType::kInt8 &&
+ output_array.data_type == ArrayDataType::kInt8) {
+ return 3;
+ }
if (conv_op.dilation_width_factor != 1 ||
conv_op.dilation_height_factor != 1) {
return 2;
@@ -185,6 +203,25 @@
}
};
+class AddN : public BuiltinOperator<AddNOperator, ::tflite::AddNOptions,
+ ::tflite::BuiltinOptions_AddNOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateAddNOptions(*builder);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {}
+
+ int GetVersion(const OperatorSignature& op_signature) const override {
+ return 1;
+ }
+};
+
class SpaceToBatchND
: public BuiltinOperator<SpaceToBatchNDOperator,
::tflite::SpaceToBatchNDOptions,
@@ -449,6 +486,12 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
+ const string& input_name = op_signature.op->inputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ // If the op take int8 input, it is version 2.
+ if (input_array.data_type == ArrayDataType::kInt8) {
+ return 2;
+ }
return 1;
}
};
@@ -1216,6 +1259,12 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
+ const string& input_name = op_signature.op->inputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ // If the op take int8 input, it is version 2.
+ if (input_array.data_type == ArrayDataType::kInt8) {
+ return 2;
+ }
return 1;
}
};
@@ -1305,6 +1354,12 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
+ const string& input_name = op_signature.op->inputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ if (input_array.data_type == ArrayDataType::kInt8) {
+ return 2;
+ }
+
return 1;
}
};
@@ -1326,6 +1381,12 @@
}
int GetVersion(const OperatorSignature& op_signature) const override {
+ const string& input_name = op_signature.op->inputs[0];
+ const Array& input_array = op_signature.model->GetArray(input_name);
+ if (input_array.data_type == ArrayDataType::kInt8) {
+ return 2;
+ }
+
return 1;
}
};
@@ -1880,6 +1941,8 @@
ops.push_back(
MakeUnique<Add>(::tflite::BuiltinOperator_ADD, OperatorType::kAdd));
ops.push_back(
+ MakeUnique<AddN>(::tflite::BuiltinOperator_ADD_N, OperatorType::kAddN));
+ ops.push_back(
MakeUnique<Div>(::tflite::BuiltinOperator_DIV, OperatorType::kDiv));
ops.push_back(
MakeUnique<Sub>(::tflite::BuiltinOperator_SUB, OperatorType::kSub));
diff --git a/tensorflow/lite/toco/tflite/operator_test.cc b/tensorflow/lite/toco/tflite/operator_test.cc
index 88f68f7..43b52c4 100644
--- a/tensorflow/lite/toco/tflite/operator_test.cc
+++ b/tensorflow/lite/toco/tflite/operator_test.cc
@@ -164,6 +164,13 @@
output_toco_op->fused_activation_function);
}
+TEST_F(OperatorTest, BuiltinAddN) {
+ AddNOperator op;
+ auto output_toco_op =
+ SerializeAndDeserialize(GetOperator("ADD_N", OperatorType::kAddN), op);
+ ASSERT_NE(output_toco_op.get(), nullptr);
+}
+
TEST_F(OperatorTest, BuiltinReducerOps) {
CheckReducerOperator<MeanOperator>("MEAN", OperatorType::kMean);
CheckReducerOperator<TensorFlowSumOperator>("SUM", OperatorType::kSum);
diff --git a/tensorflow/lite/toco/toco_convert.cc b/tensorflow/lite/toco/toco_convert.cc
index 28e7b10..2adfc1d 100644
--- a/tensorflow/lite/toco/toco_convert.cc
+++ b/tensorflow/lite/toco/toco_convert.cc
@@ -77,7 +77,7 @@
string* output_file_contents) {
std::unique_ptr<Model> model =
Import(toco_flags, model_flags, graph_def_contents);
- Transform(toco_flags, model.get());
+ TF_RETURN_IF_ERROR(TransformWithStatus(toco_flags, model.get()));
return Export(toco_flags, *model, toco_flags.allow_custom_ops(),
output_file_contents);
}
diff --git a/tensorflow/lite/toco/toco_tooling.cc b/tensorflow/lite/toco/toco_tooling.cc
index 69d7a7a..06f5182 100644
--- a/tensorflow/lite/toco/toco_tooling.cc
+++ b/tensorflow/lite/toco/toco_tooling.cc
@@ -236,7 +236,8 @@
return model;
}
-void Transform(const TocoFlags& toco_flags, Model* model) {
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+ Model* model) {
const FileFormat output_format = toco_flags.output_format();
const IODataType inference_type = toco_flags.inference_type();
@@ -258,8 +259,8 @@
// stop optimizations from crossing the input/output boundaries. For example
// this will stop BatchNorm fusing if the output node is in between a conv
// and BatchNorm layers.
- RunGraphTransformations(model, "Removing unused ops",
- {new toco::RemoveUnusedOp});
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "Removing unused ops", {new toco::RemoveUnusedOp}));
GraphTransformationsSet transformations;
MakeGeneralGraphTransformationsSet(&transformations);
@@ -307,20 +308,21 @@
identify_dilated_conv->set_identify_depthwise_conv(false);
}
transformations.Add(identify_dilated_conv);
- RunGraphTransformations(model, "general graph transformations",
- transformations);
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "general graph transformations", transformations));
if (quantize_output) {
if (toco_flags.propagate_fake_quant_num_bits()) {
- RunGraphTransformations(model,
- "fake quant propagation graph transformations",
- {new PropagateFakeQuantNumBits});
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "fake quant propagation graph transformations",
+ {new PropagateFakeQuantNumBits}));
}
- RunGraphTransformations(model, "pre-quantization graph transformations",
- {
- new HardcodeMinMax,
- new DropFakeQuant,
- });
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "pre-quantization graph transformations",
+ {
+ new HardcodeMinMax,
+ new DropFakeQuant,
+ }));
}
// Try to merge bidirectional sequence lstm or rnn if present.
@@ -328,8 +330,9 @@
bidirectional_transformations.Add(new RemoveUnusedOp);
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceLstm);
bidirectional_transformations.Add(new toco::GroupBidirectionalSequenceRnn);
- RunGraphTransformations(model, "Group bidirectional sequence lstm/rnn",
- bidirectional_transformations);
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "Group bidirectional sequence lstm/rnn",
+ bidirectional_transformations));
// Fix any issues with IO edges. This must happen after any transform that
// may modify the structure of the edges.
@@ -357,12 +360,12 @@
toco_flags.default_int16_ranges_max());
}
if (propagate_default_min_max->has_any_ranges_defined()) {
- RunGraphTransformations(
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
model, "default min-max range propagation graph transformations",
{
propagate_default_min_max.release(),
new HardcodeMinMax,
- });
+ }));
}
CheckIsReadyForQuantization(*model);
@@ -372,17 +375,18 @@
toco_flags.allow_nudging_weights_to_use_fast_gemm_kernel());
ensure_safe_for_int8_kernels->set_has_default_ranges_flag(
has_default_ranges_flag);
- RunGraphTransformations(model, "quantization graph transformations",
- {
- new RemoveTrivialQuantizedActivationFunc,
- new RemoveTrivialQuantizedMinMax,
- new Quantize,
- new RemoveFinalDequantizeOp,
- ensure_safe_for_int8_kernels,
- });
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "quantization graph transformations",
+ {
+ new RemoveTrivialQuantizedActivationFunc,
+ new RemoveTrivialQuantizedMinMax,
+ new Quantize,
+ new RemoveFinalDequantizeOp,
+ ensure_safe_for_int8_kernels,
+ }));
if (SupportsShuffledFCWeights(output_format)) {
- RunGraphTransformations(model, "shuffling of FC weights",
- {new ShuffleFCWeights});
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "shuffling of FC weights", {new ShuffleFCWeights}));
}
} else {
GraphTransformationsSet dequantization_transformations{new Dequantize};
@@ -392,8 +396,9 @@
dequantization_transformations.Add(new DropFakeQuant);
}
- RunGraphTransformations(model, "dequantization graph transformations",
- dequantization_transformations);
+ TF_RETURN_IF_ERROR(RunGraphTransformationsWithStatus(
+ model, "dequantization graph transformations",
+ dequantization_transformations));
}
if (output_format == TENSORFLOW_GRAPHDEF) {
@@ -425,6 +430,7 @@
<< " billion (note that a multiply-add is counted as 2 ops).";
}
model->ops_count = ops_count;
+ return tensorflow::Status::OK();
}
tensorflow::Status Export(const TocoFlags& toco_flags, const Model& model,
diff --git a/tensorflow/lite/toco/toco_tooling.h b/tensorflow/lite/toco/toco_tooling.h
index 742e376..3699615 100644
--- a/tensorflow/lite/toco/toco_tooling.h
+++ b/tensorflow/lite/toco/toco_tooling.h
@@ -31,7 +31,12 @@
// Transforms a Model. The resulting Model is ready to be passed
// to Export with the exact same toco_flags.
-void Transform(const TocoFlags& toco_flags, Model* model);
+tensorflow::Status TransformWithStatus(const TocoFlags& toco_flags,
+ Model* model);
+inline void Transform(const TocoFlags& toco_flags, Model* model) {
+ auto s = TransformWithStatus(toco_flags, model);
+ CHECK(s.ok()) << s.error_message();
+}
// Exports the Model, which must be of the 'lowered' form returned by
// Transform, to a file of the format given by
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/README.md b/tensorflow/lite/tools/accuracy/ilsvrc/README.md
index ac3a156..28ad2e4 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/README.md
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/README.md
@@ -16,18 +16,25 @@
The path to the directory containing ground truth images.
* `ground_truth_labels`: `string` \
- Path to ground truth labels file. This file should contain the same number of labels as the number images in the ground truth directory. The labels are assumed to be in the
- same order as the sorted filename of images. See [ground truth label generation](#ground-truth-label-generation)
- section for more information about how to generate labels for images.
+ Path to ground truth labels file. This file should contain the same number
+ of labels as the number images in the ground truth directory. The labels are
+ assumed to be in the same order as the sorted filename of images. See
+ [ground truth label generation](#ground-truth-label-generation) section for
+ more information about how to generate labels for images.
-* `model_output_labels`: `string` \
+* `model_output_labels`: `string` \
Path to the file containing labels, that is used to interpret the output of
the model. E.g. in case of mobilenets, this is the path to
`mobilenet_labels.txt` where each label is in the same order as the output
1001 dimension tensor.
* `output_path`: `string` \
- This is the path to the output file. The output is a CSV file that has top-10 accuracies in each row. Each line of output file is the cumulative accuracy after processing images in a sorted order. So first line is accuracy after processing the first image, second line is accuracy after procesing first two images. The last line of the file is accuracy after processing the entire validation set.
+ This is the path to the output file. The output is a CSV file that has
+ top-10 accuracies in each row. Each line of output file is the cumulative
+ accuracy after processing images in a sorted order. So first line is
+ accuracy after processing the first image, second line is accuracy after
+ processing first two images. The last line of the file is accuracy after
+ processing the entire validation set.
and the following optional parameters:
diff --git a/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
index 04b6cb7..b730b08 100644
--- a/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
+++ b/tensorflow/lite/tools/accuracy/ilsvrc/inception_preprocessing.cc
@@ -19,7 +19,6 @@
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/errors.h"
diff --git a/tensorflow/lite/tools/benchmark/ios/README.md b/tensorflow/lite/tools/benchmark/ios/README.md
index 8142f48..ee880f0 100644
--- a/tensorflow/lite/tools/benchmark/ios/README.md
+++ b/tensorflow/lite/tools/benchmark/ios/README.md
@@ -27,9 +27,9 @@
tensorflow/lite/tools/make/build_ios_universal_lib.sh
```
-will also build `tensorflow/lite/gen/lib/benchmark-lib.a` .
+will also build `tensorflow/lite/tools/make/gen/lib/benchmark-lib.a` .
-- Now copy the downloaded model file to `benchmark_data` directory.
+- Now copy the downloaded model file to `benchmark_data` directory.
- Modify `benchmark_params.json` change the `input_layer`, `input_layer_shape`
and other benchmark parameters.
@@ -37,8 +37,8 @@
- Change `Build Phases -> Copy Bundle Resources` and add the model file to the
resources that need to be copied.
-- Ensure that `Build Phases -> Link Binary With Library` contains the
-`Accelerate framework` and `tensorflow/lite/gen/lib/benchmark-lib.a`.
+- Ensure that `Build Phases -> Link Binary With Library` contains the
+`Accelerate framework` and `tensorflow/lite/tools/make/gen/lib/benchmark-lib.a`.
- Now try running the app. The app has a single button that runs the benchmark
on the model and displays results in a text view below.
@@ -48,7 +48,7 @@
If you want detailed profiling, use the following command:
```bash
-tensorflow/lite/build_ios_universal_lib.sh -p
+tensorflow/lite/tools/make/build_ios_universal_lib.sh -p
```
Then following the same steps above and run the benchmark app. You will see the
diff --git a/tensorflow/lite/tools/make/Makefile b/tensorflow/lite/tools/make/Makefile
index e98ba9b..5c4bb4d 100644
--- a/tensorflow/lite/tools/make/Makefile
+++ b/tensorflow/lite/tools/make/Makefile
@@ -131,6 +131,14 @@
CORE_CC_EXCLUDE_SRCS += tensorflow/lite/nnapi_delegate.cc
endif
+ifeq ($(TARGET),ios)
+ CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_android.cc
+ CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_default.cc
+else
+ CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_android.cc
+ CORE_CC_EXCLUDE_SRCS += tensorflow/lite/minimal_logging_ios.cc
+endif
+
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
diff --git a/tensorflow/lite/tools/optimize/subgraph_quantizer.cc b/tensorflow/lite/tools/optimize/subgraph_quantizer.cc
index 118e055..c1ff444 100644
--- a/tensorflow/lite/tools/optimize/subgraph_quantizer.cc
+++ b/tensorflow/lite/tools/optimize/subgraph_quantizer.cc
@@ -325,6 +325,27 @@
return kTfLiteOk;
}
+TfLiteStatus SubgraphQuantizer::AsymmetricQuantizeSoftmax(
+ BuiltinOperator op_code, OperatorT* op) {
+ TF_LITE_ENSURE_EQ(this->error_reporter_, op->inputs.size(), 1);
+ TF_LITE_ENSURE_EQ(this->error_reporter_, op->outputs.size(), 1);
+
+ if (IsSubgraphInput(op->inputs[0])) {
+ TF_LITE_ENSURE_STATUS(AsymmetricQuantizeTensor(op_code, op->inputs[0]));
+ }
+
+ auto output_tensor = subgraph_->tensors[op->outputs[0]].get();
+ if (output_tensor->type != TensorType_FLOAT32) {
+ return kTfLiteOk;
+ }
+
+ // Softmax output is hardcoded to have 1/256 as scale and -128 as zero point.
+ output_tensor->type = TensorType_INT8;
+ output_tensor->quantization->scale = {1.0f / 256.0f};
+ output_tensor->quantization->zero_point = {-128};
+ return kTfLiteOk;
+}
+
bool SubgraphQuantizer::IsSubgraphInput(int32_t tensor_idx) const {
return std::find(subgraph_->inputs.begin(), subgraph_->inputs.end(),
tensor_idx) != subgraph_->inputs.end();
@@ -342,8 +363,9 @@
case BuiltinOperator_MAX_POOL_2D:
return PropagateMinMaxForAvgAndMaxPool(op_code, op);
case BuiltinOperator_SQUEEZE:
- case BuiltinOperator_SOFTMAX:
return AsymmetricQuantizeSingleInputOutputOp(op_code, op);
+ case BuiltinOperator_SOFTMAX:
+ return AsymmetricQuantizeSoftmax(op_code, op);
default:
return kTfLiteError;
}
diff --git a/tensorflow/lite/tools/optimize/subgraph_quantizer.h b/tensorflow/lite/tools/optimize/subgraph_quantizer.h
index 9d6ca7f..fd1c392 100644
--- a/tensorflow/lite/tools/optimize/subgraph_quantizer.h
+++ b/tensorflow/lite/tools/optimize/subgraph_quantizer.h
@@ -51,6 +51,12 @@
TfLiteStatus AsymmetricQuantizeSingleInputOutputOp(BuiltinOperator op_code,
OperatorT* op);
+ // Asymmetric quantizes inputs and outputs of an Softmax Op.
+ // Input is quantized with the min-max range and output is hardcoded to have
+ // 1/256 as scale and -128 as zero point.
+ TfLiteStatus AsymmetricQuantizeSoftmax(BuiltinOperator op_code,
+ OperatorT* op);
+
TfLiteStatus AsymmetricQuantizeTensor(BuiltinOperator op_code,
int32_t tensor_idx);
diff --git a/tensorflow/lite/tools/optimize/subgraph_quantizer_test.cc b/tensorflow/lite/tools/optimize/subgraph_quantizer_test.cc
index 4b23ced..7261d22 100644
--- a/tensorflow/lite/tools/optimize/subgraph_quantizer_test.cc
+++ b/tensorflow/lite/tools/optimize/subgraph_quantizer_test.cc
@@ -291,6 +291,7 @@
ASSERT_EQ(op->outputs.size(), 1);
auto float_graph = readonly_model->subgraphs()->Get(0);
+ // Verify input.
ASSERT_EQ(float_graph->tensors()->Get(op->inputs[0])->type(),
TensorType_FLOAT32);
ASSERT_EQ(float_graph->tensors()->Get(op->outputs[0])->type(),
@@ -306,12 +307,18 @@
VerifyAsymmetricQuantizationScale(*float_input_quant_params,
*input_quant_params);
+ // Verify output.
auto float_output_quant_params =
float_graph->tensors()->Get(op->outputs[0])->quantization();
auto output_quant_params =
subgraph->tensors[op->outputs[0]]->quantization.get();
- VerifyAsymmetricQuantizationScale(*float_output_quant_params,
- *output_quant_params);
+ ASSERT_EQ(float_output_quant_params->min()->size(), 1);
+ ASSERT_EQ(float_output_quant_params->max()->size(), 1);
+
+ ASSERT_EQ(output_quant_params->scale.size(), 1);
+ ASSERT_EQ(output_quant_params->zero_point.size(), 1);
+ ASSERT_EQ(1.0f / 256.0f, output_quant_params->scale[0]);
+ ASSERT_EQ(-128, output_quant_params->zero_point[0]);
}
TEST(SubgraphQuantizerTest, VerifyAvgPoolQuantization) {
diff --git a/tensorflow/opensource_only.files b/tensorflow/opensource_only.files
index 0312114..9264939 100644
--- a/tensorflow/opensource_only.files
+++ b/tensorflow/opensource_only.files
@@ -120,8 +120,8 @@
tensorflow/third_party/eigen3/unsupported/Eigen/CXX11/ThreadPool
tensorflow/third_party/eigen3/unsupported/Eigen/SpecialFunctions
tensorflow/third_party/eigen3/unsupported/Eigen/MatrixFunctions
+tensorflow/third_party/eigen3/gpu_packet_math.patch
tensorflow/third_party/eigen3/LICENSE
-tensorflow/third_party/eigen3/gebp_neon.patch
tensorflow/third_party/eigen3/BUILD
tensorflow/third_party/systemlibs/build_defs.bzl.tpl
tensorflow/third_party/systemlibs/absl_py.BUILD
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 98b9518..ca6a09c 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1195,6 +1195,7 @@
srcs = ["framework/registry_test.py"],
additional_deps = [
":framework_for_generated_wrappers",
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:client_testlib",
],
main = "framework/registry_test.py",
@@ -2265,7 +2266,7 @@
":function_def_to_graph",
":functional_ops_gen",
":gradients",
- ":gradients_impl",
+ ":gradients_util",
":graph_to_function_def",
":pywrap_tensorflow",
":util",
@@ -2278,6 +2279,7 @@
name = "while_v2",
srcs = [
"ops/while_v2.py",
+ "ops/while_v2_indexed_slices_rewriter.py",
],
srcs_version = "PY2AND3",
deps = [
@@ -2290,7 +2292,7 @@
":framework_ops",
":function_def_to_graph",
":functional_ops_gen",
- ":gradients_impl",
+ ":gradients_util",
":list_ops",
":tensor_array_ops",
":tensor_shape",
@@ -2381,6 +2383,7 @@
srcs_version = "PY2AND3",
deps = [
":gradients_impl",
+ ":gradients_util",
":unconnected_gradients",
"//tensorflow/python/eager:function",
"//tensorflow/python/eager:tape",
@@ -2404,7 +2407,6 @@
":framework",
":framework_for_generated_wrappers",
":framework_ops",
- ":functional_ops",
":image_grad",
":linalg_grad",
":linalg_ops",
@@ -2416,15 +2418,34 @@
":optional_grad",
":platform",
":random_grad",
- ":resource_variable_ops",
":tensor_array_ops",
+ ":unconnected_gradients",
+ ":util",
+ ],
+)
+
+py_library(
+ name = "gradients_util",
+ srcs = [
+ "ops/gradients_util.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":array_ops",
+ ":control_flow_ops",
+ ":control_flow_util",
+ ":framework",
+ ":framework_for_generated_wrappers",
+ ":framework_ops",
+ ":functional_ops",
+ ":math_ops",
+ ":platform",
+ ":resource_variable_ops",
":tensor_util",
":unconnected_gradients",
":util",
- ":variable_scope",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
- "//tensorflow/python/eager:tape",
"//third_party/py/numpy",
"@six_archive//:six",
],
diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py
index 7026a16..04439ba 100644
--- a/tensorflow/python/autograph/converters/call_trees.py
+++ b/tensorflow/python/autograph/converters/call_trees.py
@@ -22,231 +22,46 @@
from __future__ import division
from __future__ import print_function
-import collections
-
import gast
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import anno
-from tensorflow.python.autograph.pyct import ast_util
-from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import templates
-from tensorflow.python.util import tf_inspect
-class FunctionInfo(collections.namedtuple('FunctionInfo', ('dtype',))):
- pass
-
-
-# TODO(mdan): Move this to a separate transformer.
-KNOWN_NUMPY_FUNCTIONS = {
- ('numpy', 'random', 'binomial'): FunctionInfo(dtype='tf.int64'),
-}
-
-
-# TODO(mdan): Get rid of these interfaces. Can now depend directly on Namer.
-
-
-class FunctionNamer(object):
- """Describes the interface for CallTreeTransformer's namer."""
-
- def compiled_function_name(self,
- original_fqn,
- live_entity=None,
- owner_type=None):
- """Generate the name corresponding to the compiled version of a function.
-
- Args:
- original_fqn: string or tuple(string)
- live_entity: Callable, the actual target function, if known.
- owner_type: Optional object. If present, it indicates that the function is
- a member of the given type.
- Returns:
- string, bool
- """
- raise NotImplementedError()
-
- def compiled_class_name(self, original_fqn, live_entity=None):
- """Generate the name corresponding to the compiled version of a class.
-
- Args:
- original_fqn: string or tuple(string)
- live_entity: The actual target class, if known.
- Returns:
- string
- """
- raise NotImplementedError()
-
-
-# TODO(mdan): Rename to CallsTransformer.
+# TODO(mdan): Rename to FunctionCallsTransformer.
class CallTreeTransformer(converter.Base):
"""Transforms the call tree by renaming transformed symbols."""
- def _resolve_decorator_name(self, node):
- """Used to resolve decorator info."""
- if isinstance(node, gast.Call):
- return self._resolve_decorator_name(node.func)
- if isinstance(node, gast.Name):
- # TODO(mdan): Add test coverage for this branch.
- return self.ctx.info.namespace.get(node.id)
- if isinstance(node, gast.Attribute):
- parent = self._resolve_decorator_name(node.value)
- if parent is not None:
- return getattr(parent, node.attr)
- return None
- raise ValueError(node)
-
- def _try_resolve_target(self, node):
- """Works for methods of objects of known type."""
- if anno.hasanno(node, 'live_val'):
- return anno.getanno(node, 'live_val')
- if isinstance(node, gast.Attribute) and anno.hasanno(node, 'type'):
- owner_type = anno.getanno(node, 'type')
- if hasattr(owner_type, node.attr):
- return getattr(owner_type, node.attr)
- else:
- # TODO(mdan): We should probably return None here rather than an error.
- raise ValueError('Type "%s" has no attribute "%s". Is it dynamic?' %
- (owner_type, node.attr))
- return None
-
- def _function_is_compilable(self, target_entity):
- """Determines whether an entity can be compiled at all."""
- # TODO(mdan): Expand.
-
- if target_entity.__module__ is None:
- # Functions like builtins and NumPy don't expose a module.
- # Those in general should not be compiled.
- return False
-
- if inspect_utils.isbuiltin(target_entity):
- return False
-
- if inspect_utils.isnamedtuple(target_entity):
- # namedtuple doesn't expose its source code, making it uncompilable.
- return False
-
- return True
-
- def _should_compile(self, node, fqn):
- """Determines whether an entity should be compiled in the context."""
- # TODO(mdan): Needs cleanup. We should remove the use of fqn altogether.
- module_name = fqn[0]
- for mod in self.ctx.program.uncompiled_modules:
- if module_name.startswith(mod[0] + '.'):
- return False
-
- for i in range(1, len(fqn)):
- if fqn[:i] in self.ctx.program.uncompiled_modules:
- return False
-
- target_entity = self._try_resolve_target(node.func)
-
- if target_entity is not None:
-
- # Currently, lambdas are always converted.
- # TODO(mdan): Allow markers of the kind f = ag.do_not_convert(lambda: ...)
- if inspect_utils.islambda(target_entity):
- return True
-
- # This may be reached when "calling" a callable attribute of an object.
- # For example:
- #
- # self.fc = tf.keras.layers.Dense()
- # self.fc()
- #
- for mod in self.ctx.program.uncompiled_modules:
- if target_entity.__module__.startswith(mod[0] + '.'):
- return False
-
- # Inspect the target function decorators. If any include a @convert
- # or @do_not_convert annotation, then they must be called as they are.
- # TODO(mdan): This may be quite heavy. Perhaps always dynamically convert?
- # To parse and re-analyze each function for every call site could be quite
- # wasteful. Maybe we could cache the parsed AST?
- try:
- target_node, _ = parser.parse_entity(target_entity)
- target_node = target_node.body[0]
- except TypeError:
- # Functions whose source we cannot access are compilable (e.g. wrapped
- # to py_func).
- return True
-
- # This attribute is set when the decorator was applied before the
- # function was parsed. See api.py.
- if hasattr(target_entity, '__ag_compiled'):
- return False
-
- for dec in target_node.decorator_list:
- decorator_fn = self._resolve_decorator_name(dec)
- if (decorator_fn is not None and
- self.ctx.program.options.should_strip(decorator_fn)):
- return False
-
- return True
-
- def _rename_compilable_function(self, node):
- assert anno.hasanno(node.func, 'live_val')
- assert anno.hasanno(node.func, 'fqn')
- target_entity = anno.getanno(node.func, 'live_val')
- target_fqn = anno.getanno(node.func, 'fqn')
-
- if anno.hasanno(node, 'is_constructor'):
- new_name = self.ctx.namer.compiled_class_name(
- target_fqn, live_entity=target_entity)
- do_rename = True
- else:
- if anno.hasanno(node.func, 'parent_type'):
- owner_type = anno.getanno(node.func, 'parent_type')
- else:
- # Fallback - not reliable.
- owner_type = inspect_utils.getmethodclass(target_entity)
- new_name, do_rename = self.ctx.namer.compiled_function_name(
- target_fqn, live_entity=target_entity, owner_type=owner_type)
-
- if do_rename:
- if target_entity is not None:
- if tf_inspect.ismethod(target_entity):
- # The renaming process will transform it into a regular function.
- # TODO(mdan): Is this complete? How does it work with nested members?
- node.args = [node.func.value] + node.args
- node.func = templates.replace_as_expression(
- 'func_name', func_name=new_name)
+ def visit_FunctionDef(self, node):
+ node.args = self.visit(node.args)
+ node.body = self.visit_block(node.body)
+ # TODO(mdan): Is this correct for local functions?
+ node.decorator_list = []
+ if node.returns:
+ node.returns = self.visit(node.returns)
return node
- def _wrap_to_py_func_single_return(self, node, dtype):
- # TODO(mdan): Properly handle varargs, etc.
- template = """
- ag__.utils.wrap_py_func(func, dtype, (args,), kwargs, False)
- """
- return templates.replace_as_expression(
- template,
- func=node.func,
- dtype=parser.parse_expression(dtype),
- args=node.args,
- kwargs=ast_util.keywords_to_dict(node.keywords))
+ def visit_With(self, node):
+ # Context manager calls (in node.items) are not converted.
+ node.body = self.visit_block(node.body)
+ return node
- def _insert_dynamic_conversion(self, node):
- """Inlines a dynamic conversion for a dynamic function."""
- # TODO(mdan): Pass information on the statically compiled functions.
- # Having access to the statically compiled functions can help avoid
- # unnecessary compilation.
- # For example, this would lead to function `a` being compiled twice:
- #
- # def a():
- # v = b
- # b()
- # def b():
- # a()
- #
- # This is really a problem with recursive calls, which currently can
- # only be gated by a static condition, and should be rare.
- # TODO(mdan): It probably makes sense to use dynamic conversion every time.
- # Before we could convert all the time though, we'd need a reasonable
- # caching mechanism.
+ def visit_Call(self, node):
+ # TODO(mdan): Refactor converted_call as a 'Call' operator.
+
+ # Calls to the internal 'ag__' module are never converted (though their
+ # arguments might be).
+ full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
+ if full_name.startswith('ag__.'):
+ return self.generic_visit(node)
+ if (full_name == 'print' and
+ not self.ctx.program.options.uses(converter.Feature.BUILTIN_FUNCTIONS)):
+ return self.generic_visit(node)
+
template = """
ag__.converted_call(func, owner, options, args)
"""
@@ -256,6 +71,7 @@
else:
func = node.func
owner = parser.parse_expression('None')
+
new_call = templates.replace_as_expression(
template,
func=func,
@@ -266,68 +82,9 @@
args=node.args)
# TODO(mdan): Improve the template mechanism to better support this.
new_call.keywords = node.keywords
+
return new_call
- def visit_FunctionDef(self, node):
- node.args = self.visit(node.args)
- node.body = self.visit_block(node.body)
- node.decorator_list = []
- node.returns = self.visit_block(node.returns)
- return node
-
- def visit_Call(self, node):
- if anno.hasanno(node.func, 'live_val'):
- target_entity = anno.getanno(node.func, 'live_val')
-
- if anno.hasanno(node.func, 'fqn'):
- target_fqn = anno.getanno(node.func, 'fqn')
- else:
- target_fqn = None
-
- if self._function_is_compilable(target_entity):
- if self._should_compile(node, target_fqn):
- node = self._rename_compilable_function(node)
- else:
- node = self.generic_visit(node)
- return node
-
- elif target_fqn and target_fqn in KNOWN_NUMPY_FUNCTIONS:
- # TODO(mdan): Should we replace these with equivalent TF ops instead?
- node = self._wrap_to_py_func_single_return(
- node, KNOWN_NUMPY_FUNCTIONS[target_fqn].dtype)
-
- elif inspect_utils.isbuiltin(target_entity):
- # Note: Any builtin that passed the builtins converter is assumed to be
- # safe for graph mode.
- return node
-
- elif inspect_utils.isnamedtuple(target_entity):
- # Although not compilable, we assume they are safe for graph mode.
- node = self.generic_visit(node)
- return node
-
- else:
- # TODO(mdan): Instert dynamic conversion here instead.
- raise NotImplementedError(
- 'py_func with return values (unknown function)')
- else:
- # Special cases
- # TODO(mdan): These need a systematic review - there may be more.
-
- # 1. super() calls - these are preserved. The class conversion mechanism
- # will ensure that they return the correct value.
- if ast_util.matches(node, parser.parse_expression('super(_)')):
- return node
-
- # 2. super().method calls - these are preserved as well, when the
- # conversion processes the entire class.
- if (ast_util.matches(node, parser.parse_expression('super(_)._(_)')) and
- self.ctx.info.owner_type is not None):
- return node
-
- node = self._insert_dynamic_conversion(node)
- return node
-
def transform(node, ctx):
"""Transform function call to the compiled counterparts.
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py
index 454d75d..6ee56bf 100644
--- a/tensorflow/python/autograph/converters/call_trees_test.py
+++ b/tensorflow/python/autograph/converters/call_trees_test.py
@@ -18,147 +18,49 @@
from __future__ import division
from __future__ import print_function
-import collections
-
-import numpy as np
-
from tensorflow.python.autograph.converters import call_trees
from tensorflow.python.autograph.core import converter_testing
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
class CallTreesTest(converter_testing.TestCase):
- def test_basic(self):
+ def test_normal_function(self):
- def test_fn_1(_):
- raise ValueError('This should not be called in the compiled version.')
-
- def other_test_fn_1(a):
- return a + 1
-
- def test_fn_2(a):
- return test_fn_1(a) + 1
-
- ns = {'test_fn_1': test_fn_1}
- node, ctx = self.prepare(test_fn_2, ns)
- node = call_trees.transform(node, ctx)
-
- with self.compiled(node, ns) as result:
- new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',))
- setattr(result, new_name, other_test_fn_1)
- self.assertEquals(result.test_fn_2(1), 3)
-
- def test_dynamic_function(self):
-
- def test_fn_1():
- raise ValueError('This should be masked by the mock in self.compiled.')
-
- def test_fn_2(f):
+ def test_fn(f):
return f() + 3
- with self.converted(test_fn_2, call_trees, {}) as result:
- # 10 = 7 (from the mock) + 3 (from test_fn_2)
- self.assertEquals(10, result.test_fn_2(test_fn_1))
+ with self.converted(test_fn, call_trees, {}) as result:
+ self.assertEquals(
+ result.test_fn(None),
+ converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 3)
+ self.assertListEqual(self.dynamic_calls, [()])
- def test_basic_method(self):
+ def test_class_method(self):
class TestClass(object):
- def test_fn_1(self, a):
- return a + 1
+ def test_method(self, a):
+ return self.other_method(a) + 1
- def test_fn_2(self, a):
- return self.test_fn_1(a) + 1
+ tc = TestClass()
+ with self.converted(TestClass.test_method, call_trees, {}) as result:
+ self.assertEquals(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
+ result.test_method(tc, 1))
+ self.assertListEqual(self.dynamic_calls, [(1,)])
- ns = {'TestClass': TestClass}
- node, ctx = self.prepare(
- TestClass.test_fn_2,
- ns,
- namer=converter_testing.FakeNoRenameNamer(),
- arg_types={'self': (TestClass.__name__, TestClass)})
- node = call_trees.transform(node, ctx)
+ def test_object_method(self):
- with self.compiled(node, ns) as result:
- tc = TestClass()
- self.assertEquals(3, result.test_fn_2(tc, 1))
+ class TestClass(object):
- def test_known_called_lambda(self):
+ def test_method(self, a):
+ return self.other_method(a) + 1
- l = lambda x: x
-
- def test_fn(a):
- return l(a)
-
- ns = {'l': l}
- node, ctx = self.prepare(test_fn, ns)
- node = call_trees.transform(node, ctx)
-
- with self.compiled(node, ns) as result:
- self.assertEquals(1, result.test_fn(1))
-
- def test_known_called_namedtuple(self):
-
- nt = collections.namedtuple('TestNamedTuple', ['a'])
-
- def test_fn(a):
- return nt(a)
-
- ns = {'nt': nt}
- node, ctx = self.prepare(test_fn, ns)
- node = call_trees.transform(node, ctx)
-
- with self.compiled(node, ns) as result:
- self.assertEquals(nt(1), result.test_fn(1))
-
- def test_py_func_known_function(self):
-
- def test_fn():
- return np.random.binomial(2, 0.5)
-
- with self.converted(test_fn, call_trees, {'np': np},
- dtypes.int64) as result:
- with self.cached_session() as sess:
- self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
- self.assertIn(self.evaluate(result.test_fn()), (0, 1, 2))
-
- def test_uncompiled_modules(self):
-
- def test_fn(a):
- a = math_ops.multiply(a, constant_op.constant(2))
- a = math_ops.add(a, constant_op.constant(1))
- return a
-
- ns = {'math_ops': math_ops, 'constant_op': constant_op}
- node, ctx = self.prepare(
- test_fn,
- ns,
- arg_types=set(((math_ops.__name__,), (constant_op.__name__,))))
- node = call_trees.transform(node, ctx)
-
- with self.compiled(node, ns) as result:
- with self.cached_session() as sess:
- result_tensor = result.test_fn(constant_op.constant(1))
- self.assertEquals(self.evaluate(result_tensor), 3)
-
- def test_call_to_decorated_function(self):
-
- def decorator(f):
- return f
-
- @decorator
- def called_fn(a):
- return a
-
- def test_fn(a):
- return called_fn(a)
-
- node, ctx = self.prepare(test_fn, {'called_fn': called_fn})
- node = call_trees.transform(node, ctx)
+ tc = TestClass()
+ with self.converted(tc.test_method, call_trees, {}) as result:
+ self.assertEquals(converter_testing.RESULT_OF_MOCK_CONVERTED_CALL + 1,
+ result.test_method(tc, 1))
+ self.assertListEqual(self.dynamic_calls, [(1,)])
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/converters/continue_statements.py b/tensorflow/python/autograph/converters/continue_statements.py
index c3b6679..725e053 100644
--- a/tensorflow/python/autograph/converters/continue_statements.py
+++ b/tensorflow/python/autograph/converters/continue_statements.py
@@ -121,6 +121,28 @@
node.orelse = self.visit_block(node.orelse)
return node
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body = self.visit_block(node.body,
+ after_visit=self._postprocess_statement)
+ return node
+
+ def visit_Try(self, node):
+ node.body = self.visit_block(node.body,
+ after_visit=self._postprocess_statement)
+ node.orelse = self.visit_block(node.orelse,
+ after_visit=self._postprocess_statement)
+ # In Python 3.8 and later continue is allowed in finally blocks
+ node.finalbody = self.visit_block(node.finalbody,
+ after_visit=self._postprocess_statement)
+ node.handlers = self.visit_block(node.handlers)
+ return node
+
+ def visit_ExceptHandler(self, node):
+ node.body = self.visit_block(node.body,
+ after_visit=self._postprocess_statement)
+ return node
+
def transform(node, ctx):
transformer = ContinueCanonicalizationTransformer(ctx)
diff --git a/tensorflow/python/autograph/converters/control_flow.py b/tensorflow/python/autograph/converters/control_flow.py
index 0416820..2526d30 100644
--- a/tensorflow/python/autograph/converters/control_flow.py
+++ b/tensorflow/python/autograph/converters/control_flow.py
@@ -278,26 +278,15 @@
# the loop state, regardless of whether they are later used or not.
loop_state = body_scope.modified & live_in
- undefined_lives = loop_state - defined_in
+ # Variable that are used or defined inside the loop, but not defined
+ # before entering the loop
+ undefined_lives = ((loop_state - defined_in)
+ | ((body_scope.modified - live_in) & live_out))
# Only simple variables must be defined. The composite ones will be
# implicitly checked at runtime.
undefined_simple_lives = {v for v in undefined_lives if v.is_simple()}
- if undefined_simple_lives:
- raise NameError(
- 'cannot convert loop: it includes symbols that are undefined'
- ' when entering the loop: {}'.format(
- self._fmt_symbols(undefined_simple_lives)))
- live_defs_in_loop = (body_scope.modified - live_in) & live_out
- if live_defs_in_loop:
- # TODO(mdan): Include reference to explanation why.
- raise NotImplementedError(
- 'cannot convert loop: it includes symbols that are defined'
- ' inside the loop, but used later: {}. To fix, initialize'
- ' these symbols before the loop'.format(
- self._fmt_symbols(live_defs_in_loop)))
-
- return loop_state, reserved_symbols
+ return loop_state, reserved_symbols, undefined_simple_lives
def _state_constructs(self, loop_state, reserved_symbols):
loop_state = list(loop_state)
@@ -321,7 +310,7 @@
def visit_While(self, node):
self.generic_visit(node)
- loop_state, reserved_symbols = self._get_loop_state(node)
+ loop_state, reserved_symbols, possibly_undef = self._get_loop_state(node)
# Note: one might expect we can dispatch based on the loop condition.
# But because that is dependent on the state, it cannot be evaluated ahead
@@ -386,12 +375,13 @@
extra_deps=tuple(s.ast() for s in cond_closure),
)
- return node
+ undefined_assigns = self._create_undefined_assigns(possibly_undef)
+ return undefined_assigns + node
def visit_For(self, node):
self.generic_visit(node)
- loop_state, reserved_symbols = self._get_loop_state(node)
+ loop_state, reserved_symbols, possibly_undef = self._get_loop_state(node)
loop_state, state_ssf, state_ast_tuple, ssf_map = self._state_constructs(
loop_state, reserved_symbols)
node_body = ast_util.rename_symbols(node.body, ssf_map)
@@ -446,7 +436,8 @@
body_name=self.ctx.namer.new_symbol('loop_body', reserved_symbols),
body=node_body)
- return node
+ undefined_assigns = self._create_undefined_assigns(possibly_undef)
+ return undefined_assigns + node
def transform(node, ctx):
diff --git a/tensorflow/python/autograph/converters/control_flow_test.py b/tensorflow/python/autograph/converters/control_flow_test.py
index 1e73dca..d3accd3 100644
--- a/tensorflow/python/autograph/converters/control_flow_test.py
+++ b/tensorflow/python/autograph/converters/control_flow_test.py
@@ -78,17 +78,6 @@
self.assertTransformedResult(test_fn, constant_op.constant(5), 0)
- def test_while_variable_defined_in_body(self):
- def bad_while_loop(n):
- while n > 0:
- n -= 1
- s = n
- return s
-
- node, ctx = self.prepare(bad_while_loop, {})
- with self.assertRaises(NameError):
- control_flow.transform(node, ctx)
-
@test_util.run_deprecated_v1
def test_if_basic(self):
@@ -225,16 +214,6 @@
self.assertEqual(result.test_fn(5), 10)
self.assertEqual(eval_count[0], 1)
- def test_for_variable_defined_in_body(self):
- def bad_for_loop(n):
- for i in range(n):
- s = i
- return s
-
- node, ctx = self.prepare(bad_for_loop, {})
- with self.assertRaises(NameError):
- control_flow.transform(node, ctx)
-
@test_util.run_deprecated_v1
def test_for_tuple_unpacking(self):
def test_fn(x_list):
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index c6c6c3b..723acab 100644
--- a/tensorflow/python/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -108,6 +108,28 @@
node.orelse, _ = self._visit_statement_block(node, node.orelse)
return node
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body, definitely_returns = self._visit_statement_block(node, node.body)
+ if definitely_returns:
+ anno.setanno(node, STMT_DEFINITELY_RETURNS, True)
+ return node
+
+ def visit_Try(self, node):
+ # We could decide whether a 'try' DEFINITELY_RETURNS based on its components
+ # It is not clear whether we want to do anything with this given
+ # a 'try' is likely to throw an exception in some circumstances.
+ node.body, _ = self._visit_statement_block(node, node.body)
+ node.orelse, _ = self._visit_statement_block(node, node.orelse)
+ node.finalbody, _ = self._visit_statement_block(node, node.finalbody)
+ node.handlers = self.visit_block(node.handlers)
+ return node
+
+ def visit_ExceptHandler(self, node):
+ # To determine whether `try` DEFINITELY_RETURNS we need to revisit this.
+ node.body, _ = self._visit_statement_block(node, node.body)
+ return node
+
def visit_If(self, node):
node.test = self.visit(node.test)
@@ -305,6 +327,22 @@
node.orelse = self._visit_statement_block(node, node.orelse)
return node
+ def visit_With(self, node):
+ node.items = self.visit_block(node.items)
+ node.body = self._visit_statement_block(node, node.body)
+ return node
+
+ def visit_Try(self, node):
+ node.body = self._visit_statement_block(node, node.body)
+ node.orelse = self._visit_statement_block(node, node.orelse)
+ node.finalbody = self._visit_statement_block(node, node.finalbody)
+ node.handlers = self.visit_block(node.handlers)
+ return node
+
+ def visit_ExceptHandler(self, node):
+ node.body = self._visit_statement_block(node, node.body)
+ return node
+
def visit_If(self, node):
node.test = self.visit(node.test)
node.body = self._visit_statement_block(node, node.body)
@@ -370,5 +408,6 @@
transformer = ReturnStatementsTransformer(
ctx, default_to_null_return=default_to_null_return)
node = transformer.visit(node)
+ transformer.debug_print_src(node)
return node
diff --git a/tensorflow/python/autograph/converters/side_effect_guards.py b/tensorflow/python/autograph/converters/side_effect_guards.py
index d7c0951..7e556d9 100644
--- a/tensorflow/python/autograph/converters/side_effect_guards.py
+++ b/tensorflow/python/autograph/converters/side_effect_guards.py
@@ -125,6 +125,10 @@
node.orelse = self._visit_and_reindent(node.orelse)
return node
+ # TODO(b/123995141) Remove once ExceptionHandlers are in the CFG
+ def visit_ExceptHandler(self, node):
+ return node
+
def visit_Expr(self, node):
self.generic_visit(node)
if isinstance(node.value, gast.Call):
diff --git a/tensorflow/python/autograph/core/config.py b/tensorflow/python/autograph/core/config.py
index 574f819..5dce3e6 100644
--- a/tensorflow/python/autograph/core/config.py
+++ b/tensorflow/python/autograph/core/config.py
@@ -28,21 +28,16 @@
'float': float,
}
-DEFAULT_UNCOMPILED_MODULES = set((
- ('tensorflow',),
- (utils.__name__,),
- # All of tensorflow's subpackages. Unlike the root tf module, they don't
- # have well-known names. Not referring to the module directly to avoid
- # circular imports.
- (
- utils.__name__[:-len('.python.autograph.utils')],),
-))
+def internal_module_name(name):
+ full_name = utils.__name__
+ name_start = full_name.find(name)
+ name_end = name_start + len(name) + 1
+ return full_name[:name_end]
-NO_SIDE_EFFECT_CONSTRUCTORS = set(('tensorflow',))
-# TODO(mdan): Also allow controlling the generated names.
-# TODO(mdan); Consolidate all internal imports into a single __ag module.
+DEFAULT_UNCOMPILED_MODULES = set(((internal_module_name('tensorflow'),),))
+
COMPILED_IMPORT_STATEMENTS = (
'from __future__ import print_function',
)
diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py
index 3ad6968..abe0f25 100644
--- a/tensorflow/python/autograph/impl/api.py
+++ b/tensorflow/python/autograph/impl/api.py
@@ -18,7 +18,10 @@
from __future__ import division
from __future__ import print_function
+import collections
+import copy
import functools
+import pdb
import sys
from enum import Enum
@@ -35,9 +38,9 @@
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import errors
from tensorflow.python.autograph.pyct import inspect_utils
+from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import tensor_util
-from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -161,7 +164,9 @@
def converted_call(f, owner, options, *args, **kwargs):
"""Compiles a function call inline. For internal use only."""
- logging.vlog(logging.DEBUG, 'Converted call: %s; owner: %s', f, owner)
+ logging.log(1,
+ 'Converted call: %s; owner: %s\n args: %s\n kwargs: %s\n',
+ f, owner, args, kwargs)
if owner is not None:
if not isinstance(f, str):
@@ -180,8 +185,25 @@
if inspect_utils.isbuiltin(f):
return py_builtins.overload_of(f)(*args, **kwargs)
+ # TODO(b/122265385): Remove this bypass.
+ if ('wrapt' in sys.modules and
+ hasattr(sys.modules['wrapt'], 'FunctionWrapper') and
+ isinstance(f, sys.modules['wrapt'].FunctionWrapper)):
+ logging.warn(
+ 'Entity {} appears to be decorated by wrapt, which is not yet supported'
+ ' by AutoGraph. The function will be called without transformation.'
+ ' You may however apply AutoGraph before the decorator.'.format(f), 1)
+ logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
+ return f(*args, **kwargs)
+
+ # Other built-in modules are permanently whitelisted.
+ # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
+ if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
+ f in copy.__dict__.values()):
+ logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
+ return f(*args, **kwargs)
+
# TODO(mdan): This needs cleanup.
- # In particular, we may want to avoid renaming functions altogether.
if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
# TODO(mdan): This may be inconsistent in certain situations.
@@ -207,91 +229,118 @@
if not options.internal_convert_user_code:
return f(*args, **kwargs)
- # Unwrap functools.partial objects
- # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
- while isinstance(f, functools.partial):
- args = f.args + args
- new_kwargs = {}
- if f.keywords is not None:
- new_kwargs.update(f.keywords)
- new_kwargs.update(kwargs)
- kwargs = new_kwargs
- f = f.func
+ # TODO(mdan): Move this entire block inside to_graph.
+ try: # Begin of transformation error guards
- if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
- # Regular functions
- target_entity = f
- arg_map_target = f
- f_self = inspect_utils.getmethodself(f)
+ # Unwrap functools.partial objects
+ # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
+ while isinstance(f, functools.partial):
+ args = f.args + args
+ new_kwargs = {}
+ if f.keywords is not None:
+ new_kwargs.update(f.keywords)
+ new_kwargs.update(kwargs)
+ kwargs = new_kwargs
+ f = f.func
- # TODO(b/119246461): This may be more elegantly handled using __get__?
- if f_self is not None:
- # If this is a method call, it may or may not include self.
- #
- # Example when self is included:
- # converted_call(to_graph(foo.bar), foo)
- #
- # Example when self is not included:
- # super(...).foo(args)
- #
- if owner is not None and (not args or args[0] is not owner):
- effective_args = (owner,) + args
- else:
- # When the owner is not specified, use the result of
- # inspect_utils.getmethodclass.
- # TODO(b/119246461): Make sure an owner is always specified.
- if not args or args[0] is not f_self:
- effective_args = (f_self,) + args
+ if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
+ # Regular functions
+ target_entity = f
+ arg_map_target = f
+ f_self = inspect_utils.getmethodself(f)
+
+ # TODO(b/119246461): This may be more elegantly handled using __get__?
+ if f_self is not None:
+ # If this is a method call, it may or may not include self.
+ #
+ # Example when self is included:
+ # converted_call(to_graph(foo.bar), foo)
+ #
+ # Example when self is not included:
+ # super(...).foo(args)
+ #
+ if owner is not None and (not args or args[0] is not owner):
+ effective_args = (owner,) + args
else:
- effective_args = (f_self,) + args[1:]
- partial_types = (f_self,)
- else:
+ # When the owner is not specified, use the result of
+ # inspect_utils.getmethodclass.
+ # TODO(b/119246461): Make sure an owner is always specified.
+ if not args or args[0] is not f_self:
+ effective_args = (f_self,) + args
+ else:
+ effective_args = (f_self,) + args[1:]
+ partial_types = (f_self,)
+ else:
+ effective_args = args
+ partial_types = ()
+
+ elif tf_inspect.isclass(f):
+ # Constructors
+ target_entity = f
+ arg_map_target = f.__init__
effective_args = args
partial_types = ()
- elif tf_inspect.isclass(f):
- # Constructors
- target_entity = f
- arg_map_target = f.__init__
- effective_args = args
- partial_types = ()
+ elif hasattr(f, '__call__') and hasattr(f, '__class__'):
+ # Callable objects
+ target_entity = f.__call__
+ arg_map_target = f.__call__
+ effective_args = (f,) + args
+ partial_types = (f.__class__,)
- elif hasattr(f, '__call__') and hasattr(f, '__class__'):
- # Callable objects
- target_entity = f.__call__
- arg_map_target = f.__call__
- effective_args = (f,) + args
- partial_types = (f.__class__,)
+ else:
+ raise NotImplementedError('unknown callable type "%s"' % type(f))
- else:
- raise NotImplementedError('unknown callable type "%s"' % type(f))
+ arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
+ arg_types = {}
+ for name, arg in arg_values.items():
+ arg_class = arg.__class__
+ arg_types[name] = (arg_class.__name__, arg_class)
- arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
- arg_types = {}
- for name, arg in arg_values.items():
- arg_class = arg.__class__
- arg_types[name] = (arg_class.__name__, arg_class)
+ # When called from within a decorator, this is the only indication that
+ # the function is a method - it appears that the decorator is applied
+ # before the method is bound.
+ if not partial_types:
+ if 'self' in arg_values:
+ if tf_inspect.isclass(arg_values['self'].__class__):
+ partial_types = (arg_values['self'].__class__,)
+ elif 'cls' in arg_values:
+ if tf_inspect.isclass(arg_values['cls']):
+ partial_types = (arg_values['cls'],)
- # When called from within a decorator, this is the only indication that
- # the function is a method - it appears that the decorator is applied
- # before the method is bound.
- if not partial_types:
- if 'self' in arg_values:
- if tf_inspect.isclass(arg_values['self'].__class__):
- partial_types = (arg_values['self'].__class__,)
- elif 'cls' in arg_values:
- if tf_inspect.isclass(arg_values['cls']):
- partial_types = (arg_values['cls'],)
+ logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
+ partial_types)
- converted_f = to_graph(
- target_entity,
- recursive=options.recursive,
- arg_values=arg_values,
- arg_types=arg_types,
- experimental_optional_features=options.optional_features,
- experimental_strip_decorators=options.strip_decorators,
- experimental_verbose=options.verbose,
- experimental_partial_types=partial_types)
+ converted_f = to_graph(
+ target_entity,
+ recursive=options.recursive,
+ arg_values=arg_values,
+ arg_types=arg_types,
+ experimental_optional_features=options.optional_features,
+ experimental_strip_decorators=options.strip_decorators,
+ experimental_verbose=options.verbose,
+ experimental_partial_types=partial_types)
+
+ if logging.has_verbosity(2):
+ logging.log(2, 'Defaults of %s : %s', converted_f,
+ converted_f.__defaults__)
+ callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
+ formatted_callargs = '\n'.join(
+ ' {}: {}'.format(k, v) for k, v in callargs.items())
+ logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)
+
+ # TODO(mdan): Reduce this list.
+ except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
+ KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
+ ValueError, IOError) as e:
+ logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
+ logging.warn(
+ 'Entity %s could not be transformed and will be staged without change.'
+ ' Error details can be found in the logs when running with the env'
+ ' variable AUTOGRAPH_VERBOSITY=5. Please report this to the AutoGraph'
+ ' team. Cause: %s', target_entity, e)
+
+ return f(*args, **kwargs)
result = converted_f(*effective_args, **kwargs)
@@ -436,8 +485,15 @@
compiled_module.__dict__[key] = val
compiled = getattr(compiled_module, name)
- if tf_inspect.isfunction(entity):
+ if hasattr(entity, '__defaults__'):
+ logging.log(3, 'Default args mapping: %s has: %s', entity,
+ entity.__defaults__)
compiled.__defaults__ = entity.__defaults__
+ else:
+ logging.log(3, 'Default args mapping: %s has no __defaults__', entity)
+
+ logging.log(3, 'Namespace of %s includes: %s', compiled,
+ compiled_module.__dict__.keys())
if hasattr(compiled, '__globals__'):
# Remove self to avoid circular references. This will probably only work
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
index d5561ba..5192809 100644
--- a/tensorflow/python/autograph/impl/api_test.py
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+import collections
import functools
import gc
@@ -26,6 +27,7 @@
from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.utils import py_func
from tensorflow.python.framework import constant_op
@@ -46,7 +48,7 @@
class ApiTest(test.TestCase):
@test_util.run_deprecated_v1
- def test_decorator_recurses(self):
+ def test_decorator_recursive(self):
class TestClass(object):
@@ -69,7 +71,7 @@
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
- def test_decorator_does_not_recurse(self):
+ def test_decorator_not_recursive(self):
class TestClass(object):
@@ -90,7 +92,7 @@
self.assertListEqual([0, 1], self.evaluate(x).tolist())
@test_util.run_deprecated_v1
- def test_decorator_calls_unconverted_graph(self):
+ def test_convert_then_do_not_convert_graph(self):
class TestClass(object):
@@ -105,14 +107,13 @@
return x
tc = TestClass()
- with self.cached_session() as sess:
- x = tc.test_method(
- constant_op.constant([2, 4]), constant_op.constant(1),
- constant_op.constant(-2))
- self.assertListEqual([0, 1], self.evaluate(x).tolist())
+ x = tc.test_method(
+ constant_op.constant((2, 4)), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
- def test_decorator_calls_unconverted_py_func(self):
+ def test_convert_then_do_not_convert_py_func(self):
class TestClass(object):
@@ -132,11 +133,10 @@
return x
tc = TestClass()
- with self.cached_session() as sess:
- x = tc.test_method(
- constant_op.constant([2, 4]), constant_op.constant(1),
- constant_op.constant(-2))
- self.assertListEqual([0, 1], self.evaluate(x).tolist())
+ x = tc.test_method(
+ constant_op.constant((2, 4)), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertAllEqual((0, 1), self.evaluate(x))
@test_util.run_deprecated_v1
def test_decorator_calls_decorated(self):
@@ -265,6 +265,26 @@
converter.ConversionOptions(), tc)
self.assertEqual(1, self.evaluate(x))
+ def test_converted_call_method_converts_recursively(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def other_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ def test_method(self):
+ return self.other_method()
+
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(tc.test_method, None,
+ converter.ConversionOptions(recursive=True), tc)
+ self.assertEqual(1, self.evaluate(x))
+
def test_converted_call_method_by_class(self):
class TestClass(object):
@@ -334,6 +354,22 @@
constant_op.constant(0))
self.assertTrue(self.evaluate(x))
+ def test_converted_call_then_already_converted_dynamic(self):
+
+ @api.convert()
+ def g(x):
+ if x > 0:
+ return x
+ else:
+ return -x
+
+ def f(g, x):
+ return g(x)
+
+ x = api.converted_call(f, None, converter.ConversionOptions(),
+ g, constant_op.constant(1))
+ self.assertEqual(self.evaluate(x), 1)
+
@test_util.run_deprecated_v1
def test_converted_call_no_user_code(self):
@@ -397,6 +433,24 @@
self.evaluate(variables.global_variables_initializer())
self.assertAllEqual([[0.0, 0.0]], self.evaluate(x))
+ def test_converted_call_namedtuple(self):
+
+ opts = converter.ConversionOptions()
+
+ x = api.converted_call(collections.namedtuple, None, opts,
+ 'TestNamedtuple', ('a', 'b'))
+
+ self.assertTrue(inspect_utils.isnamedtuple(x))
+
+ def test_converted_call_namedtuple_via_collections(self):
+
+ opts = converter.ConversionOptions()
+
+ x = api.converted_call('namedtuple', collections, opts,
+ 'TestNamedtuple', ('a', 'b'))
+
+ self.assertTrue(inspect_utils.isnamedtuple(x))
+
def test_converted_call_lambda(self):
opts = converter.ConversionOptions()
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index 6912106..1ac2e33 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -20,6 +20,7 @@
import functools
import imp
+import unittest
import gast
@@ -80,29 +81,33 @@
m = functools
else:
m = tf_inspect.getmodule(o)
- if not hasattr(m, '__name__'):
- # Note: typically it's builtins that fall in this category. Builtins will
- # be handled by specific code that follows this screening layer.
- logging.log(2, '%s is NOT whitelisted: unknown module name', o)
- return False
- for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
- if m.__name__.startswith(prefix):
- logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
+ if hasattr(m, '__name__'):
+ # Builtins typically have unnamed modules.
+ for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
+ if m.__name__.startswith(prefix):
+ logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
+ return True
+
+ # Temporary -- whitelist tensorboard modules.
+ # TODO(b/122731813): Remove.
+ if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
+ logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
return True
if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
- logging.log(2, '%s is whitelisted: already converted', o)
+ logging.log(2, 'Whitelisted: %s: already converted', o)
return True
- if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
- hasattr(o, '__call__') and hasattr(o, '__class__')):
+ if hasattr(o, '__call__'):
# Callable objects: whitelisted if their __call__ method is.
- call_whitelisted = is_whitelisted_for_graph(o.__call__)
- if call_whitelisted:
- logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
- return call_whitelisted
+ # The type check avoids infinite recursion around the __call__ method
+ # of function objects.
+ if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__): # pylint: disable=unidiomatic-typecheck
+ logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
+ return True
+ owner_class = None
if tf_inspect.ismethod(o):
# Methods of whitelisted classes are also whitelisted, even if they are
# bound via user subclasses.
@@ -121,9 +126,13 @@
owner_class = inspect_utils.getmethodclass(o)
if owner_class is not None:
+ if issubclass(owner_class, unittest.TestCase):
+ logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
+ return True
+
owner_class = inspect_utils.getdefiningclass(o, owner_class)
if is_whitelisted_for_graph(owner_class):
- logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
+ logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
owner_class)
return True
@@ -132,13 +141,14 @@
# because they don't expose source code. But we assume they are safe for
# graph mode since they are just containers.
if tf_inspect.isclass(o) and len(o.__bases__) > 1:
- logging.warn_first_n(
- 'Entity {} looks like a namedtuple subclass. If it has any custom'
- ' methods, they will not be converted by AutoGraph.'.format(o), 1)
- logging.log(2, '%s is whitelisted: named tuple', o)
+ logging.warn(
+ 'Entity {} looks like a namedtuple subclass. Its constructor will'
+ ' not be converted by AutoGraph, but if it has any custom methods,'
+ ' those will be.'.format(o), 1)
+ logging.log(2, 'Whitelisted: %s: named tuple', o)
return True
- logging.log(2, '%s is NOT whitelisted', o)
+ logging.log(2, 'Not whitelisted: %s: default rule', o)
return False
@@ -207,6 +217,9 @@
if logging.has_verbosity(2):
logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
compiler.ast_to_source(node))
+ if logging.has_verbosity(4):
+ for n in node:
+ logging.log(4, 'Compiled AST of %s:\n\n%s\n', o, gast.dump(n))
if program_ctx.options.recursive:
while True:
diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py
index cd893e3..ddda408 100644
--- a/tensorflow/python/autograph/impl/conversion_test.py
+++ b/tensorflow/python/autograph/impl/conversion_test.py
@@ -92,11 +92,9 @@
conversion.entity_to_graph(f, program_ctx, None, None)
self.assertTrue(f in program_ctx.dependency_cache)
- self.assertTrue(g in program_ctx.dependency_cache)
+ self.assertFalse(g in program_ctx.dependency_cache)
f_node = program_ctx.dependency_cache[f][0]
- g_node = program_ctx.dependency_cache[g][0]
self.assertEqual('tf__f', f_node.name)
- self.assertEqual('tf__g', g_node.name)
def test_entity_to_graph_class_hierarchy(self):
diff --git a/tensorflow/python/autograph/pyct/cfg.py b/tensorflow/python/autograph/pyct/cfg.py
index fdfcd4d..0cedfa8 100644
--- a/tensorflow/python/autograph/pyct/cfg.py
+++ b/tensorflow/python/autograph/pyct/cfg.py
@@ -393,6 +393,8 @@
def _connect_jump_to_finally_sections(self, node):
"""Connects a jump node to the finally sections protecting it."""
cursor = set((node,))
+ if node not in self.finally_sections:
+ return cursor
for guard_section_id in self.finally_sections[node]:
guard_begin, guard_ends = self.finally_section_subgraphs[guard_section_id]
self._connect_nodes(cursor, guard_begin)
@@ -620,10 +622,10 @@
leaving_node = self.lexical_scopes.pop()
assert node == leaving_node
- def _get_enclosing_scopes(self, include, stop_at):
+ def _get_enclosing_finally_scopes(self, stop_at):
included = []
for node in reversed(self.lexical_scopes):
- if isinstance(node, include):
+ if isinstance(node, gast.Try) and node.finalbody:
included.append(node)
if isinstance(node, stop_at):
return node, included
@@ -635,10 +637,8 @@
def _process_exit_statement(self, node, *exits_nodes_of_type):
# Note: this is safe because we process functions separately.
- try_node, guards = self._get_enclosing_scopes(
- include=(gast.Try,),
- stop_at=tuple(exits_nodes_of_type),
- )
+ try_node, guards = self._get_enclosing_finally_scopes(
+ tuple(exits_nodes_of_type))
if try_node is None:
raise ValueError(
'%s that is not enclosed by any of %s' % (node, exits_nodes_of_type))
@@ -646,10 +646,8 @@
def _process_continue_statement(self, node, *loops_to_nodes_of_type):
# Note: this is safe because we process functions separately.
- try_node, guards = self._get_enclosing_scopes(
- include=(gast.Try,),
- stop_at=tuple(loops_to_nodes_of_type),
- )
+ try_node, guards = self._get_enclosing_finally_scopes(
+ tuple(loops_to_nodes_of_type))
if try_node is None:
raise ValueError('%s that is not enclosed by any of %s' %
(node, loops_to_nodes_of_type))
@@ -698,10 +696,7 @@
self._process_basic_statement(node)
def visit_Raise(self, node):
- try_node, guards = self._get_enclosing_scopes(
- include=(gast.Try,),
- stop_at=(gast.FunctionDef,),
- )
+ try_node, guards = self._get_enclosing_finally_scopes((gast.FunctionDef,))
if try_node is None:
raise ValueError('%s that is not enclosed by any FunctionDef' % node)
self.builder.add_error_node(node, guards)
@@ -797,16 +792,13 @@
for stmt in node.orelse:
self.visit(stmt)
- if node.handlers:
- # TODO(mdan): Should we still support bare try/except? Might be confusing.
- raise NotImplementedError('exceptions are not yet supported')
-
self._exit_lexical_scope(node)
- self.builder.enter_finally_section(node)
- for stmt in node.finalbody:
- self.visit(stmt)
- self.builder.exit_finally_section(node)
+ if node.finalbody:
+ self.builder.enter_finally_section(node)
+ for stmt in node.finalbody:
+ self.visit(stmt)
+ self.builder.exit_finally_section(node)
def visit_With(self, node):
# TODO(mdan): Mark the context manager's exit call as exit guard.
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index 6d9bc43..eab01ee 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -31,7 +31,7 @@
# These functions test negative for isinstance(*, types.BuiltinFunctionType)
# and inspect.isbuiltin, and are generally not visible in globals().
-# TODO(mdan): Find a more generic way to test this - just enumerate __builtin__?
+# TODO(mdan): Remove this.
SPECIAL_BUILTINS = {
'dict': dict,
'enumerate': enumerate,
@@ -42,6 +42,7 @@
'print': print,
'range': range,
'tuple': tuple,
+ 'type': type,
'zip': zip
}
@@ -73,7 +74,7 @@
def isbuiltin(f):
"""Returns True if the argument is a built-in function."""
- if f in SPECIAL_BUILTINS.values():
+ if f in six.moves.builtins.__dict__.values():
return True
if isinstance(f, types.BuiltinFunctionType):
return True
@@ -125,6 +126,10 @@
if visited is None:
visited = set()
+ # Copy the dict to avoid "changed size error" during concurrent invocations.
+ # TODO(mdan): This is on the hot path. Can we avoid the copy?
+ namespace = dict(namespace)
+
for name in namespace:
# The value may be referenced by more than one symbol, case in which
# any symbol will be fine. If the program contains symbol aliases that
diff --git a/tensorflow/python/autograph/pyct/origin_info.py b/tensorflow/python/autograph/pyct/origin_info.py
index 230468e..f41b328 100644
--- a/tensorflow/python/autograph/pyct/origin_info.py
+++ b/tensorflow/python/autograph/pyct/origin_info.py
@@ -18,6 +18,7 @@
from __future__ import print_function
import collections
+import difflib
import os
import tokenize
@@ -27,6 +28,8 @@
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.pyct import pretty_printer
+from tensorflow.python.autograph.utils import ag_logging as logging
from tensorflow.python.util import tf_inspect
@@ -75,9 +78,11 @@
self.source_code_line)
def __repr__(self):
- return '{}:{}:{}'.format(
- os.path.split(self.loc.filename)[1], self.loc.lineno,
- self.loc.col_offset)
+ if self.loc.filename:
+ return '{}:{}:{}'.format(
+ os.path.split(self.loc.filename)[1], self.loc.lineno,
+ self.loc.col_offset)
+ return '<no file>:{}:{}'.format(self.loc.lineno, self.loc.col_offset)
# TODO(mdan): This source map should be a class - easier to refer to.
@@ -103,32 +108,47 @@
resolve(reparsed_nodes, code)
result = {}
- for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
- # Note: generated code might not be mapped back to its origin.
- # TODO(mdan): Generated code should always be mapped to something.
- origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
- final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
- if origin_info is None or final_info is None:
- continue
-
- line_loc = LineLocation(filename, final_info.loc.lineno)
-
- existing_origin = result.get(line_loc)
- if existing_origin is not None:
- # Overlaps may exist because of child nodes, but almost never to
- # different line locations. Exception make decorated functions, where
- # both lines are mapped to the same line in the AST.
-
- # Line overlaps: keep bottom node.
- if existing_origin.loc.line_loc == origin_info.loc.line_loc:
- if existing_origin.loc.lineno >= origin_info.loc.lineno:
- continue
-
- # In case of overlaps, keep the leftmost node.
- if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
+ try:
+ for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
+ # Note: generated code might not be mapped back to its origin.
+ # TODO(mdan): Generated code should always be mapped to something.
+ origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
+ final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
+ if origin_info is None or final_info is None:
continue
- result[line_loc] = origin_info
+ line_loc = LineLocation(filename, final_info.loc.lineno)
+
+ existing_origin = result.get(line_loc)
+ if existing_origin is not None:
+ # Overlaps may exist because of child nodes, but almost never to
+ # different line locations. Exception make decorated functions, where
+ # both lines are mapped to the same line in the AST.
+
+ # Line overlaps: keep bottom node.
+ if existing_origin.loc.line_loc == origin_info.loc.line_loc:
+ if existing_origin.loc.lineno >= origin_info.loc.lineno:
+ continue
+
+ # In case of overlaps, keep the leftmost node.
+ if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
+ continue
+
+ result[line_loc] = origin_info
+ except ValueError:
+ if logging.has_verbosity(3):
+ for n, rn in zip(nodes, reparsed_nodes):
+ nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
+ reparsed_nodes_str = pretty_printer.fmt(rn, color=False, noanno=True)
+ diff = difflib.context_diff(
+ nodes_str.split('\n'),
+ reparsed_nodes_str.split('\n'),
+ fromfile='Original nodes',
+ tofile='Reparsed nodes',
+ n=7)
+ diff = '\n'.join(diff)
+ logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
+ raise
return result
diff --git a/tensorflow/python/autograph/pyct/parser.py b/tensorflow/python/autograph/pyct/parser.py
index 011d80d..8b73440 100644
--- a/tensorflow/python/autograph/pyct/parser.py
+++ b/tensorflow/python/autograph/pyct/parser.py
@@ -23,6 +23,7 @@
import re
import textwrap
+import threading
import gast
import six
@@ -30,10 +31,14 @@
from tensorflow.python.util import tf_inspect
+_parse_lock = threading.Lock() # Prevents linecache concurrency errors.
+
+
def parse_entity(entity):
"""Returns the AST of given entity."""
try:
- source = tf_inspect.getsource(entity)
+ with _parse_lock:
+ source = tf_inspect.getsource_no_unwrap(entity)
except (IOError, OSError) as e:
raise ValueError(
'Unable to locate the source code of {}. Note that functions defined'
diff --git a/tensorflow/python/autograph/pyct/pretty_printer.py b/tensorflow/python/autograph/pyct/pretty_printer.py
index bacc1e4..a92017f 100644
--- a/tensorflow/python/autograph/pyct/pretty_printer.py
+++ b/tensorflow/python/autograph/pyct/pretty_printer.py
@@ -25,10 +25,11 @@
class PrettyPrinter(gast.NodeVisitor):
"""Print AST nodes."""
- def __init__(self, color):
+ def __init__(self, color, noanno):
self.indent_lvl = 0
self.result = ''
self.color = color
+ self.noanno = noanno
def _color(self, string, color, attrs=None):
if self.color:
@@ -55,6 +56,15 @@
self.result += '\n'
def generic_visit(self, node, name=None):
+ # In very rare instances, a list can contain something other than a Node.
+ # e.g. Global contains a list of strings.
+ if isinstance(node, str):
+ if name:
+ self._print('%s%s="%s"' % (self._indent(), name, node))
+ else:
+ self._print('%s"%s"' % (self._indent(), node))
+ return
+
if node._fields:
cont = ':'
else:
@@ -68,6 +78,8 @@
self.indent_lvl += 1
for f in node._fields:
+ if self.noanno and f.startswith('__'):
+ continue
if not hasattr(node, f):
self._print('%s%s' % (self._indent(), self._warning('%s=<unset>' % f)))
continue
@@ -103,8 +115,8 @@
self.indent_lvl -= 1
-def fmt(node, color=True):
- printer = PrettyPrinter(color)
+def fmt(node, color=True, noanno=False):
+ printer = PrettyPrinter(color, noanno)
if isinstance(node, (list, tuple)):
for n in node:
printer.visit(n)
diff --git a/tensorflow/python/autograph/pyct/static_analysis/activity.py b/tensorflow/python/autograph/pyct/static_analysis/activity.py
index 4359e0a..65e4682 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/activity.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/activity.py
@@ -25,6 +25,7 @@
import weakref
import gast
+import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import qual_names
@@ -149,6 +150,14 @@
self.args = set()
+class _Comprehension(object):
+
+ no_root = True
+
+ def __init__(self):
+ self.targets = set()
+
+
class ActivityAnalyzer(transformer.Base):
"""Annotates nodes with local scope information.
@@ -199,12 +208,27 @@
if qn.owner_set & set(l.args):
return
+ # When inside a comprehension, ignore any of the comprehensions's targets.
+ # This includes attributes or slices of those arguments.
+ # This is not true in Python2, which leaks symbols.
+ if six.PY3:
+ for l in self.state[_Comprehension]:
+ if qn in l.targets:
+ return
+ if qn.owner_set & set(l.targets):
+ return
+
if isinstance(node.ctx, gast.Store):
- self.scope.mark_modified(qn)
- if qn.is_composite and composite_writes_alter_parent:
- self.scope.mark_modified(qn.parent)
- if self._in_aug_assign:
- self.scope.mark_read(qn)
+ # In comprehensions, modified symbols are the comprehension targets.
+ if six.PY3 and self.state[_Comprehension].level > 0:
+ # Like a lambda's args, they are tracked separately in Python3.
+ self.state[_Comprehension].targets.add(qn)
+ else:
+ self.scope.mark_modified(qn)
+ if qn.is_composite and composite_writes_alter_parent:
+ self.scope.mark_modified(qn.parent)
+ if self._in_aug_assign:
+ self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Load):
self.scope.mark_read(qn)
elif isinstance(node.ctx, gast.Param):
@@ -338,12 +362,41 @@
self.state[_Lambda].exit()
return node
+ def _process_iterable_comprehension(self, node):
+ # This handles ListComp, SetComp, GeneratorExp.
+ self.state[_Comprehension].enter()
+ # Note: it's important to visit the generators first to properly account
+ # for the variables local to these generators. Example: `x` is local to the
+ # expression `x for x in y`.
+ node.generators = self.visit_block(node.generators)
+ node.elt = self.visit(node.elt)
+ self.state[_Comprehension].exit()
+ return node
+
+ def visit_DictComp(self, node):
+ # Identical to _process_iterable_comprehension, different node names.
+ self.state[_Comprehension].enter()
+ node.generators = self.visit_block(node.generators)
+ node.key = self.visit(node.key)
+ node.value = self.visit(node.value)
+ self.state[_Comprehension].exit()
+ return node
+
+ def visit_ListComp(self, node):
+ return self._process_iterable_comprehension(node)
+
+ def visit_SetComp(self, node):
+ return self._process_iterable_comprehension(node)
+
+ def visit_GeneratorExp(self, node):
+ return self._process_iterable_comprehension(node)
+
def visit_arguments(self, node):
return self._process_statement(node)
def visit_FunctionDef(self, node):
# The FunctionDef node itself has a Scope object that tracks the creation
- # of its name, along with the usage of any decorator accompany it.
+ # of its name, along with the usage of any decorator accompanying it.
self._enter_scope(False)
node.decorator_list = self.visit_block(node.decorator_list)
self.scope.mark_modified(qual_names.QN(node.name))
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness.py b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
index f8b8d7f..691b786 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/liveness.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness.py
@@ -219,6 +219,10 @@
frozenset(self.current_analyzer.out[cfg_node]))
return node
+ def visit_ExceptHandler(self, node):
+ # TODO(b/123995141) Add Exception Handlers to the CFG
+ return node
+
def resolve(node, source_info, graphs):
"""Resolves the live symbols at the exit of control flow statements.
diff --git a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
index 9738c6d..f14b1a3 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/liveness_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
@@ -243,6 +245,62 @@
self.assertHasLiveIn(fn_body[0], ('a', 'x', 'y'))
+ def test_live_in_generator_comprehension(self):
+
+ def test_fn(y):
+ if all(x for x in y):
+ return
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ if six.PY2:
+ self.assertHasLiveIn(fn_body[0], ('all', 'x', 'y'))
+ else:
+ self.assertHasLiveIn(fn_body[0], ('all', 'y'))
+
+ def test_live_in_list_comprehension(self):
+
+ def test_fn(y):
+ if [x for x in y]:
+ return
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ if six.PY2:
+ self.assertHasLiveIn(fn_body[0], ('x', 'y'))
+ else:
+ self.assertHasLiveIn(fn_body[0], ('y',))
+
+ def test_live_in_set_comprehension(self):
+
+ def test_fn(y):
+ if {x for x in y}:
+ return
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ if six.PY2:
+ self.assertHasLiveIn(fn_body[0], ('x', 'y'))
+ else:
+ self.assertHasLiveIn(fn_body[0], ('y',))
+
+ def test_live_in_dict_comprehension(self):
+
+ def test_fn(y):
+ if {k: v for k, v in y}:
+ return
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ if six.PY2:
+ self.assertHasLiveIn(fn_body[0], ('k', 'v', 'y'))
+ else:
+ self.assertHasLiveIn(fn_body[0], ('y',))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
index d1587d8..6f0f09e 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions.py
@@ -223,6 +223,10 @@
def visit_global(self, node):
raise NotImplementedError()
+ def visit_ExceptHandler(self, node):
+ # TODO(b/123995141) Add Exception Handlers to the CFG
+ return node
+
def visit_Name(self, node):
if self.current_analyzer is None:
# Names may appear outside function defs - for example in class
@@ -232,7 +236,8 @@
analyzer = self.current_analyzer
cfg_node = self.current_cfg_node
- assert cfg_node is not None, 'name node outside of any statement?'
+ assert cfg_node is not None, ('name node, %s, outside of any statement?'
+ % node.id)
qn = anno.getanno(node, anno.Basic.QN)
if isinstance(node.ctx, gast.Load):
diff --git a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
index fd91721..848c546 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/reaching_definitions_test.py
@@ -18,6 +18,8 @@
from __future__ import division
from __future__ import print_function
+import six
+
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import cfg
from tensorflow.python.autograph.pyct import parser
@@ -294,6 +296,24 @@
self.assertNotSameDef(source, target)
self.assertSameDef(target, retval)
+ def test_comprehension_leaking(self):
+
+ def test_fn(a):
+ all(x for x in a)
+ return x # pylint:disable=undefined-variable
+
+ node = self._parse_and_analyze(test_fn)
+ fn_body = node.body[0].body
+
+ listcomp_target = fn_body[0].value.args[0].generators[0].target
+ retval = fn_body[1].value
+
+ # Python2 leaks comprehension symbols. Python3 doesn't.
+ if six.PY2:
+ self.assertSameDef(retval, listcomp_target)
+ else:
+ self.assertHasDefs(retval, 0)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 0f0b861..831eb6d 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -91,6 +91,10 @@
self._ctx_override = None
return self.generic_visit(node)
+ def visit_comprehension(self, node):
+ self._ctx_override = None
+ return self.generic_visit(node)
+
class ReplaceTransformer(gast.NodeTransformer):
"""Replace AST nodes."""
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index cdb44b8..bd6b451 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -238,6 +238,16 @@
source = parser.parse_expression('[a(b(1))]')
templates.replace_as_expression(template, bar=source)
+ def test_star_comprehension_in_function_call(self):
+ template = """
+ a = foo(func, args)
+ """
+ source = parser.parse_expression('bar(*[i for i in range(j)])')
+ node = templates.replace(template, func=source.func, args=source.args)
+ arg_node = node[0].value.args[1].value
+ self.assertIsInstance(arg_node.generators[0].target.ctx, gast.Store)
+ self.assertIsInstance(arg_node.elt.ctx, gast.Load)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/autograph/utils/ag_logging.py b/tensorflow/python/autograph/utils/ag_logging.py
index 847000a..cd737a8 100644
--- a/tensorflow/python/autograph/utils/ag_logging.py
+++ b/tensorflow/python/autograph/utils/ag_logging.py
@@ -110,7 +110,7 @@
global verbosity_level
if verbosity_level is not None:
return verbosity_level
- return os.getenv(VERBOSITY_VAR_NAME, DEFAULT_VERBOSITY)
+ return int(os.getenv(VERBOSITY_VAR_NAME, DEFAULT_VERBOSITY))
def has_verbosity(level):
@@ -131,5 +131,9 @@
print(msg % args)
+def warn(msg, *args, **kwargs):
+ logging.warn(msg, *args, **kwargs)
+
+
def warn_first_n(msg, *args, **kwargs):
logging.log_first_n(logging.WARN, msg, *args, **kwargs)
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 4f84bcb..3ba5468 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -27,7 +27,7 @@
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 1, 30)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2019, 2, 6)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/compiler/tensorrt/BUILD b/tensorflow/python/compiler/tensorrt/BUILD
new file mode 100644
index 0000000..a382579
--- /dev/null
+++ b/tensorflow/python/compiler/tensorrt/BUILD
@@ -0,0 +1,176 @@
+# Description:
+# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow
+# and provide TensorRT operators and converter package.
+# APIs are meant to change over time.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_copts",
+)
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
+load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
+load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+
+exports_files(glob([
+ "test/testdata/*",
+]))
+
+py_library(
+ name = "init_py",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":tf_trt_integration_test_base",
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:errors",
+ ],
+)
+
+py_library(
+ name = "trt_ops_py",
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/compiler/tf2tensorrt:trt_ops",
+ "//tensorflow/compiler/tf2tensorrt:trt_ops_loader",
+ ],
+)
+
+py_library(
+ name = "trt_convert_py",
+ srcs = ["trt_convert.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":wrap_conversion",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python:session",
+ "//tensorflow/python:tf_optimizer",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:tag_constants",
+ ],
+)
+
+# TODO(aaroey): this wrapper has been causing troubles of double linking, so
+# either get rid of it, or split to make it contain minimum dependencies.
+tf_py_wrap_cc(
+ name = "wrap_conversion",
+ srcs = ["trt_conversion.i"],
+ copts = tf_copts(),
+ swig_includes = [
+ "//tensorflow/python:platform/base.i",
+ ],
+ deps = [
+ "//tensorflow/compiler/tf2tensorrt:test_utils",
+ "//tensorflow/compiler/tf2tensorrt:trt_conversion",
+ "//tensorflow/compiler/tf2tensorrt:trt_op_kernels",
+ "//third_party/python_runtime:headers",
+ ],
+)
+
+py_library(
+ name = "tf_trt_integration_test_base",
+ srcs = ["test/tf_trt_integration_test_base.py"],
+ deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+cuda_py_test(
+ name = "trt_convert_test",
+ srcs = ["trt_convert_test.py"],
+ additional_deps = [
+ ":trt_convert_py",
+ ":trt_ops_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:graph_util",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_constants",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
+ "//tensorflow/python/saved_model:utils",
+ "//tensorflow/python/tools:freeze_graph_lib",
+ "//tensorflow/python/tools:saved_model_utils",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+)
+
+cuda_py_tests(
+ name = "tf_trt_integration_test",
+ srcs = [
+ "test/base_test.py",
+ "test/batch_matmul_test.py",
+ "test/biasadd_matmul_test.py",
+ "test/binary_tensor_weight_broadcast_test.py",
+ "test/concatenation_test.py",
+ "test/const_broadcast_test.py",
+ "test/conv2d_test.py",
+ "test/dynamic_input_shapes_test.py",
+ "test/identity_output_test.py",
+ "test/int32_test.py",
+ "test/lru_cache_test.py",
+ "test/memory_alignment_test.py",
+ "test/multi_connection_neighbor_engine_test.py",
+ "test/neighboring_engine_test.py",
+ "test/quantization_test.py",
+ "test/rank_two_test.py",
+ "test/reshape_transpose_test.py",
+ "test/topk_test.py",
+ "test/unary_test.py",
+ "test/vgg_block_nchw_test.py",
+ "test/vgg_block_test.py",
+ ],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_windows",
+ "nomac",
+ ],
+)
+
+cuda_py_test(
+ name = "quantization_mnist_test",
+ srcs = ["test/quantization_mnist_test.py"],
+ additional_deps = [
+ ":tf_trt_integration_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/keras:keras",
+ "//tensorflow/python/estimator:estimator",
+ ],
+ data = [
+ "test/testdata/checkpoint",
+ "test/testdata/model.ckpt-46900.data-00000-of-00001",
+ "test/testdata/model.ckpt-46900.index",
+ ],
+ tags = [
+ "no_cuda_on_cpu_tap",
+ "no_pip",
+ "no_tap", # It is not able to download the mnist data.
+ "no_windows",
+ "nomac",
+ ],
+)
diff --git a/tensorflow/contrib/tensorrt/README.md b/tensorflow/python/compiler/tensorrt/README.md
similarity index 100%
rename from tensorflow/contrib/tensorrt/README.md
rename to tensorflow/python/compiler/tensorrt/README.md
diff --git a/tensorflow/python/compiler/tensorrt/__init__.py b/tensorflow/python/compiler/tensorrt/__init__.py
new file mode 100644
index 0000000..88fb691
--- /dev/null
+++ b/tensorflow/python/compiler/tensorrt/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the python wrapper for TensorRT graph transforms."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import errors
+
+# pylint: disable=unused-import,g-import-not-at-top,line-too-long
+try:
+ from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
+ from tensorflow.python.compiler.tensorrt.trt_convert import add_test_value
+ from tensorflow.python.compiler.tensorrt.trt_convert import calib_graph_to_infer_graph
+ from tensorflow.python.compiler.tensorrt.trt_convert import clear_test_values
+ from tensorflow.python.compiler.tensorrt.trt_convert import create_inference_graph
+ from tensorflow.python.compiler.tensorrt.trt_convert import enable_test_value
+ from tensorflow.python.compiler.tensorrt.trt_convert import get_test_value
+ from tensorflow.python.compiler.tensorrt.trt_convert import is_tensorrt_enabled
+except errors.NotFoundError as e:
+ no_trt_message = (
+ '**** Failed to initialize TensorRT. This is either because the TensorRT'
+ ' installation path is not in LD_LIBRARY_PATH, or because you do not have'
+ ' it installed. If not installed, please go to'
+ ' https://developer.nvidia.com/tensorrt to download and install'
+ ' TensorRT ****')
+ print(no_trt_message)
+ raise e
+# pylint: enable=unused-import,g-import-not-at-top,line-too-long
diff --git a/tensorflow/contrib/tensorrt/test/base_test.py b/tensorflow/python/compiler/tensorrt/test/base_test.py
similarity index 98%
rename from tensorflow/contrib/tensorrt/test/base_test.py
rename to tensorflow/python/compiler/tensorrt/test/base_test.py
index 17e0b6f..cc31099 100644
--- a/tensorflow/contrib/tensorrt/test/base_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/base_test.py
@@ -20,8 +20,8 @@
import numpy as np
-from tensorflow.contrib.tensorrt.python import trt_convert
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt import trt_convert
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/batch_matmul_test.py
rename to tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py
index 46e3407..49ad09a 100644
--- a/tensorflow/contrib/tensorrt/test/batch_matmul_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/batch_matmul_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
similarity index 97%
rename from tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
rename to tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
index ca23629..2b7bbbc 100644
--- a/tensorflow/contrib/tensorrt/test/biasadd_matmul_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/biasadd_matmul_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
rename to tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py
index 846fd00..7e1d3af 100644
--- a/tensorflow/contrib/tensorrt/test/binary_tensor_weight_broadcast_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/binary_tensor_weight_broadcast_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/concatenation_test.py b/tensorflow/python/compiler/tensorrt/test/concatenation_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/concatenation_test.py
rename to tensorflow/python/compiler/tensorrt/test/concatenation_test.py
index 5d8742a..f30324e 100644
--- a/tensorflow/contrib/tensorrt/test/concatenation_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/concatenation_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/const_broadcast_test.py
rename to tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
index 9137d00..2d76466 100644
--- a/tensorflow/contrib/tensorrt/test/const_broadcast_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/const_broadcast_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/conv2d_test.py b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py
similarity index 79%
rename from tensorflow/contrib/tensorrt/test/conv2d_test.py
rename to tensorflow/python/compiler/tensorrt/test/conv2d_test.py
index e7993b4..326cad5 100644
--- a/tensorflow/contrib/tensorrt/test/conv2d_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/conv2d_test.py
@@ -20,12 +20,13 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test
@@ -187,5 +188,46 @@
return ["TRTEngineOp_0"]
+class Conv2DTranposeTest(trt_test.TfTrtIntegrationTestBase):
+
+ def GetParams(self):
+ """Testing conversion of conv2d_transpose (AKA Conv2DBackpropInput)"""
+ np.random.seed(1234)
+ dtype = dtypes.float32
+ input_name = "input"
+ n, c, h, w = 13, 3, 7, 11
+ num_filters = 8
+ input_dims = [n, c, h, w]
+ output_name = "output"
+ g = ops.Graph()
+ with g.as_default():
+ inp = array_ops.placeholder(
+ dtype=dtype, shape=[None] + input_dims[1:], name=input_name)
+ with g.device("/GPU:0"):
+ weights_shape = [2, 2, num_filters, c]
+ weights = constant_op.constant(
+ np.random.randn(*weights_shape), dtype=dtype)
+ output_shape = constant_op.constant([n, num_filters, h * 2, w * 2],
+ dtype=dtypes.int32)
+ output = nn_ops.conv2d_transpose(
+ inp,
+ weights,
+ output_shape,
+ strides=[1, 1, 2, 2],
+ padding="SAME",
+ data_format="NCHW")
+ output = array_ops.identity(output, name=output_name)
+ return trt_test.TfTrtIntegrationTestParams(
+ gdef=g.as_graph_def(),
+ input_names=[input_name],
+ input_dims=[[input_dims]],
+ output_names=[output_name],
+ expected_output_dims=[[[n, num_filters, h * 2, w * 2]]])
+
+ def ExpectedEnginesToBuild(self, run_params):
+ """Return the expected engines to build."""
+ return ["TRTEngineOp_0"]
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
similarity index 97%
rename from tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py
rename to tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
index cc28cd6..cb358d4 100644
--- a/tensorflow/contrib/tensorrt/test/dynamic_input_shapes_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/dynamic_input_shapes_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/identity_output_test.py b/tensorflow/python/compiler/tensorrt/test/identity_output_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/identity_output_test.py
rename to tensorflow/python/compiler/tensorrt/test/identity_output_test.py
index b568eed..23a72c5 100644
--- a/tensorflow/contrib/tensorrt/test/identity_output_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/identity_output_test.py
@@ -25,7 +25,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/int32_test.py b/tensorflow/python/compiler/tensorrt/test/int32_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/int32_test.py
rename to tensorflow/python/compiler/tensorrt/test/int32_test.py
index 8cf5387..6d44469 100644
--- a/tensorflow/contrib/tensorrt/test/int32_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/int32_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/lru_cache_test.py b/tensorflow/python/compiler/tensorrt/test/lru_cache_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/lru_cache_test.py
rename to tensorflow/python/compiler/tensorrt/test/lru_cache_test.py
index 7702413..18e6d32 100644
--- a/tensorflow/contrib/tensorrt/test/lru_cache_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/lru_cache_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/memory_alignment_test.py
rename to tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
index cc64329..89625aa 100644
--- a/tensorflow/contrib/tensorrt/test/memory_alignment_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/memory_alignment_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
rename to tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
index a14bb03..d04c695 100644
--- a/tensorflow/contrib/tensorrt/test/multi_connection_neighbor_engine_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/multi_connection_neighbor_engine_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
rename to tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
index 06a86bb..1f7189f 100644
--- a/tensorflow/contrib/tensorrt/test/neighboring_engine_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/neighboring_engine_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
similarity index 98%
rename from tensorflow/contrib/tensorrt/test/quantization_mnist_test.py
rename to tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
index d68211a..1d7792c 100644
--- a/tensorflow/contrib/tensorrt/test/quantization_mnist_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/quantization_mnist_test.py
@@ -21,10 +21,10 @@
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
-from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.protobuf import config_pb2
from tensorflow.python import data
from tensorflow.python import keras
+from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.model_fn import EstimatorSpec
from tensorflow.python.estimator.model_fn import ModeKeys
@@ -265,7 +265,7 @@
def testEval(self):
if not trt_convert.is_tensorrt_enabled():
return
- model_dir = test.test_src_dir_path('contrib/tensorrt/test/testdata')
+ model_dir = test.test_src_dir_path('python/compiler/tensorrt/test/testdata')
accuracy_tf_native = self._Run(
is_training=False,
diff --git a/tensorflow/contrib/tensorrt/test/quantization_test.py b/tensorflow/python/compiler/tensorrt/test/quantization_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/quantization_test.py
rename to tensorflow/python/compiler/tensorrt/test/quantization_test.py
index ce1b25e..086e070 100644
--- a/tensorflow/contrib/tensorrt/test/quantization_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/quantization_test.py
@@ -20,8 +20,8 @@
import numpy as np
-from tensorflow.contrib.tensorrt.python import trt_convert
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt import trt_convert
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/rank_two_test.py b/tensorflow/python/compiler/tensorrt/test/rank_two_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/rank_two_test.py
rename to tensorflow/python/compiler/tensorrt/test/rank_two_test.py
index 97159bb..a951638 100644
--- a/tensorflow/contrib/tensorrt/test/rank_two_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/rank_two_test.py
@@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py b/tensorflow/python/compiler/tensorrt/test/reshape_transpose_test.py
similarity index 98%
rename from tensorflow/contrib/tensorrt/test/reshape_transpose_test.py
rename to tensorflow/python/compiler/tensorrt/test/reshape_transpose_test.py
index 7fb2cbd..423d70f 100644
--- a/tensorflow/contrib/tensorrt/test/reshape_transpose_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/reshape_transpose_test.py
@@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/tensorrt/test/testdata/checkpoint b/tensorflow/python/compiler/tensorrt/test/testdata/checkpoint
similarity index 100%
rename from tensorflow/contrib/tensorrt/test/testdata/checkpoint
rename to tensorflow/python/compiler/tensorrt/test/testdata/checkpoint
diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001 b/tensorflow/python/compiler/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001
similarity index 100%
rename from tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001
rename to tensorflow/python/compiler/tensorrt/test/testdata/model.ckpt-46900.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index b/tensorflow/python/compiler/tensorrt/test/testdata/model.ckpt-46900.index
similarity index 100%
rename from tensorflow/contrib/tensorrt/test/testdata/model.ckpt-46900.index
rename to tensorflow/python/compiler/tensorrt/test/testdata/model.ckpt-46900.index
Binary files differ
diff --git a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
similarity index 99%
rename from tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
rename to tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
index 9a00cdb..28563f0 100644
--- a/tensorflow/contrib/tensorrt/test/tf_trt_integration_test_base.py
+++ b/tensorflow/python/compiler/tensorrt/test/tf_trt_integration_test_base.py
@@ -28,9 +28,9 @@
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
-from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -259,7 +259,7 @@
"""Get config proto based on specific settings."""
conversion_params = self.GetConversionParams(run_params)
if graph_state != GraphState.ORIGINAL and run_params.use_optimizer:
- rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
+ rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
conversion_params.rewriter_config, conversion_params.max_batch_size,
conversion_params.max_workspace_size_bytes,
conversion_params.precision_mode,
diff --git a/tensorflow/contrib/tensorrt/test/topk_test.py b/tensorflow/python/compiler/tensorrt/test/topk_test.py
similarity index 97%
rename from tensorflow/contrib/tensorrt/test/topk_test.py
rename to tensorflow/python/compiler/tensorrt/test/topk_test.py
index 5524fcd..1e2bf3b 100644
--- a/tensorflow/contrib/tensorrt/test/topk_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/topk_test.py
@@ -18,7 +18,7 @@
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import constant_op
diff --git a/tensorflow/contrib/tensorrt/test/unary_test.py b/tensorflow/python/compiler/tensorrt/test/unary_test.py
similarity index 97%
rename from tensorflow/contrib/tensorrt/test/unary_test.py
rename to tensorflow/python/compiler/tensorrt/test/unary_test.py
index 497ea28..83569bc 100644
--- a/tensorflow/contrib/tensorrt/test/unary_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/unary_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
rename to tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
index b5fed73..97ee117 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_nchw_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_nchw_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/test/vgg_block_test.py b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
similarity index 96%
rename from tensorflow/contrib/tensorrt/test/vgg_block_test.py
rename to tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
index 307128f..a4fa1d6 100644
--- a/tensorflow/contrib/tensorrt/test/vgg_block_test.py
+++ b/tensorflow/python/compiler/tensorrt/test/vgg_block_test.py
@@ -20,7 +20,7 @@
import numpy as np
-from tensorflow.contrib.tensorrt.test import tf_trt_integration_test_base as trt_test
+from tensorflow.python.compiler.tensorrt.test import tf_trt_integration_test_base as trt_test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/tensorrt/trt_conversion.i b/tensorflow/python/compiler/tensorrt/trt_conversion.i
similarity index 100%
rename from tensorflow/contrib/tensorrt/trt_conversion.i
rename to tensorflow/python/compiler/tensorrt/trt_conversion.i
diff --git a/tensorflow/python/compiler/tensorrt/trt_convert.py b/tensorflow/python/compiler/tensorrt/trt_convert.py
new file mode 100644
index 0000000..33b5e50
--- /dev/null
+++ b/tensorflow/python/compiler/tensorrt/trt_convert.py
@@ -0,0 +1,672 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Exposes the Python wrapper conversion to trt_graph."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import six as _six
+# pylint: disable=unused-import,line-too-long
+from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
+from tensorflow.python.compiler.tensorrt.wrap_conversion import add_test_value
+from tensorflow.python.compiler.tensorrt.wrap_conversion import calib_convert
+from tensorflow.python.compiler.tensorrt.wrap_conversion import clear_test_values
+from tensorflow.python.compiler.tensorrt.wrap_conversion import enable_test_value
+from tensorflow.python.compiler.tensorrt.wrap_conversion import get_linked_tensorrt_version
+from tensorflow.python.compiler.tensorrt.wrap_conversion import get_loaded_tensorrt_version
+from tensorflow.python.compiler.tensorrt.wrap_conversion import get_test_value
+from tensorflow.python.compiler.tensorrt.wrap_conversion import is_tensorrt_enabled
+# pylint: enable=unused-import,line-too-long
+from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import errors_impl as _impl
+from tensorflow.python.framework import graph_util
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.saved_model import builder
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.training import saver
+
+
+def _to_bytes(s):
+ """Encode s if it is a sequence of chars."""
+ if isinstance(s, _six.text_type):
+ return s.encode("utf-8", errors="surrogateescape")
+ return s
+
+
+def _to_string(s):
+ """Decode s if it is a sequence of bytes."""
+ if isinstance(s, _six.binary_type):
+ return s.decode("utf-8")
+ return s
+
+
+class GraphConverter(object):
+ """Base class for offline converters to optimize SavedModels/GraphDefs.
+
+ A `GraphConverter` object encapsulates the environment to convert (optimize) a
+ TensorFlow SavedModel or GraphDef.
+
+ To create a custom GraphConverter:
+
+ ```python
+ class MyGraphConverter(GraphConverter):
+ ...
+
+ def get_rewriter_config(self, rewriter_config_template=None):
+ my_rewriter_config = ...
+ return my_rewriter_config
+ ```
+
+ Then to run the conversion without quantization calibration:
+
+ ```python
+ my_converter = MyGraphConverter(input_saved_model_dir="my_dir")
+ converted_graph_def = my_converter.convert()
+ my_converter.save(output_saved_model_dir) # Optional
+ ```
+
+ TODO(laigd): add calibration support.
+ """
+
+ def __init__(self,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ input_graph_def=None,
+ nodes_blacklist=None,
+ session_config=None):
+ """Initialize the converter.
+
+ Args:
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transforms. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ input_graph_def: a GraphDef object containing a model to be transformed.
+ If set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ nodes_blacklist: list of node names to prevent the converter from
+ touching. Only used when input_graph_def is not None.
+ session_config: the ConfigProto used to create a Session. It's also used
+ as a template to create a RewriterConfig for conversion. If not
+ specified, a default ConfigProto will be used.
+
+ Raises:
+ ValueError: if the combination of the parameters is invalid.
+ """
+ if input_graph_def and input_saved_model_dir:
+ raise ValueError(
+ "Can only specify one of input_graph_def and input_saved_model_dir")
+ if not input_graph_def and not input_saved_model_dir:
+ raise ValueError("Must specify one of input_graph_def and "
+ "input_saved_model_dir")
+
+ self._input_graph_def = input_graph_def
+ self._nodes_blacklist = nodes_blacklist
+ self._input_saved_model_dir = input_saved_model_dir
+ self._converted = False
+ self._grappler_meta_graph_def = None
+
+ self._input_saved_model_tags = (
+ input_saved_model_tags or [tag_constants.SERVING])
+ self._session_config = session_config or config_pb2.ConfigProto()
+
+ def get_rewriter_config(self, rewriter_config_template=None):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ rewriter_config_template: a template RewriterConfig proto used to create a
+ RewriterConfig for the conversion. The implementation should not modify
+ the template. If None, it will use a default one.
+
+ Returns:
+ A RewriterConfig proto which will be used to run the conversion using
+ Grappler.
+ """
+ raise NotImplementedError("get_rewriter_config")
+
+ def _run_conversion(self):
+ """Run Grappler's OptimizeGraph() tool to convert the graph."""
+ # Create custom ConfigProto for Grappler.
+ grappler_session_config = config_pb2.ConfigProto()
+ grappler_session_config.CopyFrom(self._session_config)
+ rewriter_config = None
+ if (grappler_session_config.HasField("graph_options") and
+ grappler_session_config.graph_options.HasField("rewrite_options")):
+ rewriter_config = grappler_session_config.graph_options.rewrite_options
+ custom_rewriter_config = self.get_rewriter_config(rewriter_config)
+ grappler_session_config.graph_options.rewrite_options.CopyFrom(
+ custom_rewriter_config)
+
+ # Run Grappler.
+ self._converted_graph_def = tf_optimizer.OptimizeGraph(
+ grappler_session_config,
+ self._grappler_meta_graph_def,
+ graph_id=b"tf_graph")
+ self._converted = True
+
+ def _convert_graph_def(self):
+ """Convert the input GraphDef."""
+ graph = ops.Graph()
+ with graph.as_default():
+ importer.import_graph_def(self._input_graph_def, name="")
+ self._grappler_meta_graph_def = saver.export_meta_graph(
+ graph_def=graph.as_graph_def(add_shapes=True), graph=graph)
+ if self._nodes_blacklist:
+ output_collection = meta_graph_pb2.CollectionDef()
+ output_list = output_collection.node_list.value
+ for i in self._nodes_blacklist:
+ if isinstance(i, ops.Tensor):
+ output_list.append(_to_bytes(i.name))
+ else:
+ output_list.append(_to_bytes(i))
+ # TODO(laigd): use another key as the self._nodes_blacklist are really
+ # not train_op.
+ self._grappler_meta_graph_def.collection_def["train_op"].CopyFrom(
+ output_collection)
+
+ self._run_conversion()
+
+ def _convert_saved_model(self):
+ """Convert the input SavedModel."""
+ graph = ops.Graph()
+ with session.Session(graph=graph, config=self._session_config) as sess:
+ input_meta_graph_def = loader.load(sess, self._input_saved_model_tags,
+ self._input_saved_model_dir)
+
+ def _gather_names(tensor_info):
+ """Get the node names from a TensorInfo."""
+ return set([tensor_info[key].name.split(":")[0] for key in tensor_info])
+
+ # Get input and outputs from all SignatureDef.
+ output_node_names = set()
+ for key in input_meta_graph_def.signature_def:
+ signature_def = input_meta_graph_def.signature_def[key]
+ output_node_names.update(_gather_names(signature_def.inputs))
+ output_node_names.update(_gather_names(signature_def.outputs))
+
+ # Freeze the variables in the SavedModel graph and copy the frozen
+ # graph over.
+ frozen_graph_def = graph_util.convert_variables_to_constants(
+ sess, sess.graph.as_graph_def(add_shapes=True),
+ list(output_node_names))
+ self._grappler_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ self._grappler_meta_graph_def.graph_def.CopyFrom(frozen_graph_def)
+
+ # Copy the collections that are not variables.
+ for key in input_meta_graph_def.collection_def:
+ # TODO(laigd): currently we use the collection key to filter out
+ # collections that depend on variable ops, but this may miss some
+ # other user-defined collections. A better way would be to use
+ # CollectionDef::NodeList for the filtering.
+ if key not in [
+ "variables", "local_variables", "model_variables",
+ "trainable_variables", "train_op", "table_initializer"
+ ]:
+ self._grappler_meta_graph_def.collection_def[key].CopyFrom(
+ input_meta_graph_def.collection_def[key])
+
+ # Copy other information.
+ self._grappler_meta_graph_def.meta_info_def.CopyFrom(
+ input_meta_graph_def.meta_info_def)
+ for key in input_meta_graph_def.signature_def:
+ self._grappler_meta_graph_def.signature_def[key].CopyFrom(
+ input_meta_graph_def.signature_def[key])
+ # TODO(laigd): maybe add back AssetFileDef.
+
+ self._run_conversion()
+
+ def convert(self):
+ """Run the conversion.
+
+ Returns:
+ The converted GraphDef.
+ """
+ assert not self._converted
+
+ if self._input_graph_def:
+ self._convert_graph_def()
+ else:
+ self._convert_saved_model()
+ return self._converted_graph_def
+
+ def save(self, output_saved_model_dir):
+ """Save the converted graph as a SavedModel.
+
+ Args:
+ output_saved_model_dir: construct a SavedModel using the converted
+ GraphDef and save it to the specified directory. This option only works
+ when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None in
+ __init__().
+
+ Raises:
+ ValueError: if the input to the converter is a GraphDef instead of a
+ SavedModel.
+ """
+ assert self._converted
+
+ if self._input_graph_def:
+ raise ValueError(
+ "Not able to save to a SavedModel since input is a GraphDef")
+
+ # Write the transformed graphdef as SavedModel.
+ saved_model_builder = builder.SavedModelBuilder(output_saved_model_dir)
+ with ops.Graph().as_default():
+ importer.import_graph_def(self._converted_graph_def, name="")
+ # We don't use any specific converter here.
+ with session.Session(config=self._session_config) as sess:
+ saved_model_builder.add_meta_graph_and_variables(
+ sess,
+ self._input_saved_model_tags,
+ signature_def_map=self._grappler_meta_graph_def.signature_def)
+ # Ignore other meta graphs from the input SavedModel.
+ saved_model_builder.save()
+
+
+class TrtPrecisionMode(object):
+ FP32 = "FP32"
+ FP16 = "FP16"
+ INT8 = "INT8"
+
+ @staticmethod
+ def supported_precision_modes():
+ return [TrtPrecisionMode.FP32, TrtPrecisionMode.FP16, TrtPrecisionMode.INT8]
+
+
+# Use a large enough number as the default max_workspace_size for TRT engines,
+# so it can produce reasonable performance results with the default.
+DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES = 1 << 30
+
+
+class TrtGraphConverter(GraphConverter):
+ """A GraphConverter for TRT transformation."""
+
+ _TRT_CALIBRATION_RESOURCE_CONTAINER_NAME = "TF_TRT_Calibration"
+
+ @classmethod
+ def get_tensorrt_rewriter_config(
+ cls,
+ rewriter_config_template=None,
+ max_batch_size=1,
+ max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=None,
+ use_calibration=True):
+ """Returns a RewriterConfig proto for TRT transformation.
+
+ Args:
+ rewriter_config_template: a template RewriterConfig proto used to create a
+ TRT-enabled RewriterConfig. If None, it will use a default one.
+ max_batch_size: max size for the input batch
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the
+ 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph
+ to be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the
+ TRT network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT
+ ops. If the number of cached engines is already at max but none of them
+ can serve the input, the TRTEngineOp will fall back to run the TF
+ function based on which the TRTEngineOp is created.
+ cached_engine_batches: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should be <= maximum_cached_engines, and the dynamic TRT op will use
+ this list to determine the batch sizes of the cached engines, instead of
+ making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ use_calibration: this argument is ignored if precision_mode is not INT8.
+ If set to True, a calibration graph will be created to calibrate the
+ missing ranges. The calibration graph must be converted to an inference
+ graph using calib_graph_to_infer_graph() after running calibration. if
+ set to False, quantization nodes will be expected for every tensor in
+ the graph (exlcuding those which will be fused). If a range is missing,
+ an error will occur. Please note that accuracy may be negatively
+ affected if there is a mismatch between which tensors TRT quantizes and
+ which tensors were trained with fake quantization.
+
+ Returns:
+ A RewriterConfig proto which sets a TensorRTOptimizer to run Grappler.
+
+ Raises:
+ TypeError: if any of the parameters are of unexpected type.
+ ValueError: if any of the parameters are of unexpected value.
+ """
+ if rewriter_config_template is not None and not isinstance(
+ rewriter_config_template, rewriter_config_pb2.RewriterConfig):
+ raise TypeError(
+ "rewriter_config_template should be a RewriterConfig proto.")
+
+ rewriter_config_with_trt = rewriter_config_pb2.RewriterConfig()
+ if rewriter_config_template is None:
+ # Layout optimizer may add Const nodes followed by Reshape nodes, thus we
+ # need to run constant folding again.
+ rewriter_config_with_trt.optimizers.extend(
+ ["constfold", "layout", "constfold"])
+ rewriter_config_with_trt.meta_optimizer_iterations = (
+ rewriter_config_pb2.RewriterConfig.ONE)
+ else:
+ rewriter_config_with_trt.CopyFrom(rewriter_config_template)
+
+ optimizer = rewriter_config_with_trt.custom_optimizers.add()
+ optimizer.name = "TensorRTOptimizer"
+ optimizer.parameter_map["minimum_segment_size"].i = minimum_segment_size
+ optimizer.parameter_map["max_batch_size"].i = max_batch_size
+ optimizer.parameter_map["is_dynamic_op"].b = is_dynamic_op
+ optimizer.parameter_map[
+ "max_workspace_size_bytes"].i = max_workspace_size_bytes
+ optimizer.parameter_map["precision_mode"].s = _to_bytes(precision_mode)
+ optimizer.parameter_map["maximum_cached_engines"].i = maximum_cached_engines
+ if cached_engine_batches:
+ optimizer.parameter_map["cached_engine_batches"].list.i.extend(
+ cached_engine_batches)
+ optimizer.parameter_map["use_calibration"].b = use_calibration
+ return rewriter_config_with_trt
+
+ def __init__(self,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ input_graph_def=None,
+ nodes_blacklist=None,
+ session_config=None,
+ max_batch_size=1,
+ max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=None,
+ use_calibration=True):
+ """Initialize the converter.
+
+ Args:
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transforms. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ input_graph_def: a GraphDef object containing a model to be transformed.
+ If set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ nodes_blacklist: list of node names to prevent the converter from
+ touching. Only used when input_graph_def is not None.
+ session_config: the ConfigProto used to create a Session. It's also used
+ as a template to create a TRT-enabled ConfigProto for conversion. If not
+ specified, a default ConfigProto will be used.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the
+ 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph
+ to be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the
+ TRT network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT
+ ops. If the number of cached engines is already at max but none of them
+ can serve the input, the TRTEngineOp will fall back to run the TF
+ function based on which the TRTEngineOp is created.
+ cached_engine_batches: a list of batch sizes used to create cached
+ engines, only used when is_dynamic_op is True. The length of the list
+ should be <= maximum_cached_engines, and the dynamic TRT op will use
+ this list to determine the batch sizes of the cached engines, instead of
+ making the decision on the fly. This is useful when we know the most
+ common batch size(s) the application is going to generate.
+ use_calibration: this argument is ignored if precision_mode is not INT8.
+ If set to True, a calibration graph will be created to calibrate the
+ missing ranges. The calibration graph must be converted to an inference
+ graph using calib_graph_to_infer_graph() after running calibration. if
+ set to False, quantization nodes will be expected for every tensor in
+ the graph (exlcuding those which will be fused). If a range is missing,
+ an error will occur. Please note that accuracy may be negatively
+ affected if there is a mismatch between which tensors TRT quantizes and
+ which tensors were trained with fake quantization.
+
+ Raises:
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
+ """
+ super(TrtGraphConverter, self).__init__(
+ input_saved_model_dir=input_saved_model_dir,
+ input_saved_model_tags=input_saved_model_tags,
+ input_graph_def=input_graph_def,
+ nodes_blacklist=nodes_blacklist,
+ session_config=session_config)
+
+ # Check compatibility of TensorRT version.
+ compiled_version = get_linked_tensorrt_version()
+ loaded_version = get_loaded_tensorrt_version()
+ version_mismatch = False
+ if loaded_version[0] < compiled_version[0]:
+ tf_logging.error(
+ "TensorRT version mismatch. Tensorflow was compiled against " +
+ "TensorRT %s but library loaded from environment is TensorRT %s" %
+ (".".join([str(x) for x in compiled_version]),
+ ".".join([str(x) for x in loaded_version])) +
+ ". Please make sure that correct version of TensorRT " +
+ "is available in the system and added to ldconfig or LD_LIBRARY_PATH")
+ raise RuntimeError("Incompatible TensorRT library version")
+ for i in zip(loaded_version, compiled_version):
+ if i[0] != i[1]:
+ tf_logging.warn("TensorRT mismatch. Compiled against version " +
+ "%s, but loaded %s. Things may not work" %
+ (".".join([str(x) for x in compiled_version]),
+ ".".join([str(x) for x in loaded_version])))
+ version_mismatch = True
+ break
+ if not version_mismatch:
+ tf_logging.info("Running against TensorRT version %s" % ".".join(
+ [str(x) for x in loaded_version]))
+
+ # Check input arguments.
+ if precision_mode.upper() not in TrtPrecisionMode.supported_precision_modes(
+ ):
+ raise ValueError(("precision mode '{}' is not supported."
+ "It should be one of {}").format(
+ precision_mode,
+ TrtPrecisionMode.supported_precision_modes))
+
+ if cached_engine_batches:
+ if not isinstance(cached_engine_batches, list):
+ raise TypeError("cached_engine_batches should be a list.")
+ if len(cached_engine_batches) > maximum_cached_engines:
+ raise ValueError("cached_engine_batches should not contain more than "
+ "maximum_cached_engines items.")
+
+ # TODO(laigd):
+ # - Get rid of is_dynamic_op option, it should always be True, and it should
+ # accept N shapes as input.
+ # - Verify in int8 mode that maximum_cached_engines and
+ # cached_engine_batches are set appropriately.
+ # - If it fails to build the int8 engine it should return error.
+ self._max_batch_size = max_batch_size
+ self._max_workspace_size_bytes = max_workspace_size_bytes
+ self._precision_mode = precision_mode
+ self._minimum_segment_size = minimum_segment_size
+ self._is_dynamic_op = is_dynamic_op
+ self._maximum_cached_engines = maximum_cached_engines
+ self._cached_engine_batches = cached_engine_batches
+ self._use_calibration = use_calibration
+
+ def get_rewriter_config(self, rewriter_config_template=None):
+ return TrtGraphConverter.get_tensorrt_rewriter_config(
+ rewriter_config_template,
+ max_batch_size=self._max_batch_size,
+ max_workspace_size_bytes=self._max_workspace_size_bytes,
+ precision_mode=self._precision_mode,
+ minimum_segment_size=self._minimum_segment_size,
+ is_dynamic_op=self._is_dynamic_op,
+ maximum_cached_engines=self._maximum_cached_engines,
+ cached_engine_batches=self._cached_engine_batches,
+ use_calibration=self._use_calibration)
+
+
+def create_inference_graph(
+ input_graph_def,
+ outputs,
+ max_batch_size=1,
+ max_workspace_size_bytes=DEFAULT_TRT_MAX_WORKSPACE_SIZE_BYTES,
+ precision_mode=TrtPrecisionMode.FP32,
+ minimum_segment_size=3,
+ is_dynamic_op=False,
+ maximum_cached_engines=1,
+ cached_engine_batches=None,
+ use_calibration=True,
+ input_saved_model_dir=None,
+ input_saved_model_tags=None,
+ output_saved_model_dir=None,
+ session_config=None):
+ """Python wrapper for the TRT transformation.
+
+ Args:
+ input_graph_def: a GraphDef object containing a model to be transformed. If
+ set to None, the graph will be read from the SavedModel loaded from
+ input_saved_model_dir.
+ outputs: list of tensors or node names for the model outputs. Only used when
+ input_graph_def is not None.
+ max_batch_size: max size for the input batch.
+ max_workspace_size_bytes: the maximum GPU temporary memory which the TRT
+ engine can use at execution time. This corresponds to the 'workspaceSize'
+ parameter of nvinfer1::IBuilder::setMaxWorkspaceSize().
+ precision_mode: one of TrtPrecisionMode.supported_precision_modes().
+ minimum_segment_size: the minimum number of nodes required for a subgraph to
+ be replaced by TRTEngineOp.
+ is_dynamic_op: whether to generate dynamic TRT ops which will build the TRT
+ network and engine at run time.
+ maximum_cached_engines: max number of cached TRT engines in dynamic TRT ops.
+ If the number of cached engines is already at max but none of them can
+ serve the input, the TRTEngineOp will fall back to run the TF function
+ based on which the TRTEngineOp is created.
+ cached_engine_batches: a list of batch sizes used to create cached engines,
+ only used when is_dynamic_op is True. The length of the list should be <=
+ maximum_cached_engines, and the dynamic TRT op will use this list to
+ determine the batch sizes of the cached engines, instead of making the
+ decision on the fly. This is useful when we know the most common batch
+ size(s) the application is going to generate.
+ use_calibration: this argument is ignored if precision_mode is not INT8. If
+ set to True, a calibration graph will be created to calibrate the missing
+ ranges. The calibration graph must be converted to an inference graph
+ using calib_graph_to_infer_graph() after running calibration. if set to
+ False, quantization nodes will be expected for every tensor in the graph
+ (exlcuding those which will be fused). If a range is missing, an error
+ will occur. Please note that accuracy may be negatively affected if there
+ is a mismatch between which tensors TRT quantizes and which tensors were
+ trained with fake quantization.
+ input_saved_model_dir: the directory to load the SavedModel which contains
+ the input graph to transforms. Used only when input_graph_def is None.
+ input_saved_model_tags: list of tags to load the SavedModel.
+ output_saved_model_dir: if not None, construct a SavedModel using the
+ returned GraphDef and save it to the specified directory. This option only
+ works when the input graph is loaded from a SavedModel, i.e. when
+ input_saved_model_dir is specified and input_graph_def is None.
+ session_config: the ConfigProto used to create a Session. It's also used as
+ a template to create a TRT-enabled ConfigProto for conversion. If not
+ specified, a default ConfigProto will be used.
+
+ Returns:
+ A GraphDef transformed from input_graph_def (or the SavedModel graph def
+ loaded from input_saved_model_dir, if input_graph_def is not present), where
+ all TRT compatible subgraphs are replaced with TRTEngineOps, and a TF
+ function is added for each of the subgraphs.
+
+ If is_dynamic_op is True, each TRTEngineOp will contain a serialized
+ subgraph GraphDef, which will be converted to a TRT engine at execution time
+ and the TRT engine will be cached for future usage. A new TRT engine will be
+ created each time when none of the cached engines match the input shapes. If
+ it fails to execute the TRT engine or the number of cached engines reaches
+ maximum_cached_engines, the op will fall back to call the corresponding TF
+ function.
+
+ If is_dynamic_op is False, each TRTEngineOp will contain a serialized TRT
+ engine created from the corresponding subgraph. No more engines will be
+ created on the fly, and the op will fall back to call the corresponding TF
+ function when it fails to execute the engine.
+
+ Raises:
+ ValueError: if the combination of the parameters is invalid.
+ RuntimeError: if the TensorRT library version is incompatible.
+ """
+ trt_converter = TrtGraphConverter(
+ input_saved_model_dir=input_saved_model_dir,
+ input_saved_model_tags=input_saved_model_tags,
+ input_graph_def=input_graph_def,
+ nodes_blacklist=outputs,
+ session_config=session_config,
+ max_batch_size=max_batch_size,
+ max_workspace_size_bytes=max_workspace_size_bytes,
+ precision_mode=precision_mode,
+ minimum_segment_size=minimum_segment_size,
+ is_dynamic_op=is_dynamic_op,
+ maximum_cached_engines=maximum_cached_engines,
+ cached_engine_batches=cached_engine_batches,
+ use_calibration=use_calibration)
+ converted_graph_def = trt_converter.convert()
+ if output_saved_model_dir:
+ trt_converter.save(output_saved_model_dir)
+ return converted_graph_def
+
+
+def calib_graph_to_infer_graph(calibration_graph_def, is_dynamic_op=False):
+ """Convert an existing calibration graph to inference graph.
+
+ Args:
+ calibration_graph_def: the calibration GraphDef object with calibration data
+ is_dynamic_op: whether to create dynamic static engines from calibration
+
+ Returns:
+ New GraphDef with TRTEngineOps placed in graph replacing calibration nodes.
+ Raises:
+ RuntimeError: if the returned status message is malformed.
+ """
+
+ is_calib_graph = False
+ for n in calibration_graph_def.node:
+ if n.op == "TRTEngineOp":
+ is_calib_graph = is_calib_graph or not n.attr["calibration_data"].s
+ if not is_calib_graph:
+ tf_logging.error(
+ "Not a calib graph. Doesn't seem to contain any calibration nodes.")
+ return None
+ graph_str = calibration_graph_def.SerializeToString()
+ out = calib_convert(graph_str, is_dynamic_op)
+ status = _to_string(out[0])
+ output_graph_def_string = out[1]
+ del graph_str # Save some memory
+ if len(status) < 2:
+ raise _impl.UnknownError(None, None, status)
+ if status[:2] != "OK":
+ msg = status.split(";")
+ if len(msg) == 1:
+ raise RuntimeError("Status message is malformed {}".format(status))
+ # pylint: disable=protected-access
+ raise _impl._make_specific_exception(None, None, ";".join(msg[1:]),
+ int(msg[0]))
+ # pylint: enable=protected-access
+ output_graph_def = graph_pb2.GraphDef()
+ output_graph_def.ParseFromString(output_graph_def_string)
+ del output_graph_def_string # Save some memory
+ return output_graph_def
diff --git a/tensorflow/contrib/tensorrt/python/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
similarity index 94%
rename from tensorflow/contrib/tensorrt/python/trt_convert_test.py
rename to tensorflow/python/compiler/tensorrt/trt_convert_test.py
index abd822c..0dbc5c1 100644
--- a/tensorflow/contrib/tensorrt/python/trt_convert_test.py
+++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py
@@ -23,10 +23,10 @@
# pylint: disable=unused-import
from tensorflow.compiler.tf2tensorrt.python.ops import trt_ops
# pylint: enable=unused-import
-from tensorflow.contrib.tensorrt.python import trt_convert
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.compiler.tensorrt import trt_convert
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
@@ -47,10 +47,14 @@
class TrtConvertTest(test_util.TensorFlowTestCase):
"""Class to test Tensorflow-TensorRT integration python API."""
+ # Use a small max_workspace_size for tests so they don't consume too much GPU
+ # memory.
+ _TRT_MAX_WORKSPACE_SIZE_BYTES = 2 << 20
+
def testGetTensorrtRewriterConfig(self):
- """Test case for trt_convert.get_tensorrt_rewriter_config()."""
- rewriter_cfg = trt_convert.get_tensorrt_rewriter_config(
- rewriter_config=None,
+ """Test case for TrtGraphConverter.get_tensorrt_rewriter_config()."""
+ rewriter_cfg = trt_convert.TrtGraphConverter.get_tensorrt_rewriter_config(
+ rewriter_config_template=None,
max_batch_size=128,
max_workspace_size_bytes=1234,
precision_mode="INT8",
@@ -147,6 +151,7 @@
input_graph_def = None if input_saved_model_dir else self._GetGraphDef()
output_graph_def = trt_convert.create_inference_graph(
input_graph_def, ["output"],
+ max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
input_saved_model_dir=input_saved_model_dir,
output_saved_model_dir=output_saved_model_dir,
session_config=self._GetConfigProto())
@@ -200,6 +205,7 @@
output_graph_def = trt_convert.create_inference_graph(
self._GetGraphDef(), ["output"],
minimum_segment_size=5,
+ max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
is_dynamic_op=False)
node_name_to_op = {node.name: node.op for node in output_graph_def.node}
self.assertEqual({
@@ -223,6 +229,7 @@
output_graph_def = trt_convert.create_inference_graph(
None,
None,
+ max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
is_dynamic_op=True,
maximum_cached_engines=2,
input_saved_model_dir=input_saved_model_dir,
@@ -266,6 +273,7 @@
None,
None,
max_batch_size=1,
+ max_workspace_size_bytes=TrtConvertTest._TRT_MAX_WORKSPACE_SIZE_BYTES,
is_dynamic_op=False,
maximum_cached_engines=2, # This is noop, added just for testing.
input_saved_model_dir=input_saved_model_dir,
diff --git a/tensorflow/python/data/benchmarks/BUILD b/tensorflow/python/data/benchmarks/BUILD
index 739e5ba..0314761 100644
--- a/tensorflow/python/data/benchmarks/BUILD
+++ b/tensorflow/python/data/benchmarks/BUILD
@@ -17,15 +17,23 @@
],
)
+py_library(
+ name = "benchmark_base",
+ srcs = ["benchmark_base.py"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
py_test(
name = "batch_benchmark",
srcs = ["batch_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:session",
+ ":benchmark_base",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
@@ -37,12 +45,8 @@
srcs = ["filter_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:session",
+ ":benchmark_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -51,9 +55,7 @@
srcs = ["from_tensor_slices_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:session",
+ ":benchmark_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -64,6 +66,7 @@
srcs = ["list_files_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
+ ":benchmark_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
@@ -78,11 +81,8 @@
srcs = ["map_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:session",
+ ":benchmark_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -91,8 +91,7 @@
srcs = ["range_benchmark.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:session",
+ ":benchmark_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/python/data/benchmarks/batch_benchmark.py b/tensorflow/python/data/benchmarks/batch_benchmark.py
index 0ccf5c5..8cad912 100644
--- a/tensorflow/python/data/benchmarks/batch_benchmark.py
+++ b/tensorflow/python/data/benchmarks/batch_benchmark.py
@@ -17,70 +17,37 @@
from __future__ import division
from __future__ import print_function
-import time
-
import numpy as np
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
-# TODO(b/119837791): Add eager benchmarks.
-class BatchBenchmark(test.Benchmark):
+class BatchBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.Dataset.batch()`."""
- def benchmarkBatchSparse(self):
+ def benchmark_batch_sparse(self):
non_zeros_per_row_values = [0, 1, 5, 10, 100]
batch_size_values = [1, 32, 64, 128, 1024]
- sparse_placeholder = array_ops.sparse_placeholder(dtype=dtypes.int64)
- batch_size_placeholder = array_ops.placeholder(dtype=dtypes.int64, shape=[])
-
- dataset = dataset_ops.Dataset.from_tensors(sparse_placeholder).repeat(
- ).batch(batch_size_placeholder)
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
-
for non_zeros_per_row in non_zeros_per_row_values:
- sparse_value = sparse_tensor.SparseTensorValue(
+ tensor = sparse_tensor.SparseTensor(
indices=np.arange(non_zeros_per_row, dtype=np.int64)[:, np.newaxis],
values=np.arange(non_zeros_per_row, dtype=np.int64),
dense_shape=[1000])
for batch_size in batch_size_values:
-
- with session.Session() as sess:
- sess.run(iterator.initializer, feed_dict={
- sparse_placeholder: sparse_value,
- batch_size_placeholder: batch_size})
- # Run five steps to warm up the session caches before taking the
- # first measurement.
- for _ in range(5):
- sess.run(next_element.indices.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.indices.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100.0
-
- self.report_benchmark(
- iters=10000,
- wall_time=median_wall_time,
- name="sparse_num_elements_%d_batch_size_%d" %
- (non_zeros_per_row, batch_size))
+ dataset = dataset_ops.Dataset.from_tensors(tensor).repeat().batch(
+ batch_size)
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=100000 // batch_size,
+ iters=1,
+ name="sparse_num_elements_%d_batch_size_%d" % (non_zeros_per_row,
+ batch_size))
if __name__ == "__main__":
- test.main()
+ benchmark_base.test.main()
diff --git a/tensorflow/python/data/benchmarks/benchmark_base.py b/tensorflow/python/data/benchmarks/benchmark_base.py
new file mode 100644
index 0000000..47f992d
--- /dev/null
+++ b/tensorflow/python/data/benchmarks/benchmark_base.py
@@ -0,0 +1,92 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data benchmarking functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.platform import test
+
+
+# TODO(b/119837791): Add eager benchmarks.
+class DatasetBenchmarkBase(test.Benchmark):
+ """Base class for dataset benchmarks."""
+
+ def run_benchmark(self, dataset, num_elements, iters=1):
+ """Benchmarks the dataset.
+
+ Runs the dataset `iters` times. In each iteration, the benchmark measures
+ the time it takes to go through `num_elements` elements of the dataset.
+
+ Args:
+ dataset: Dataset to benchmark.
+ num_elements: Number of dataset elements to iterate through each benchmark
+ iteration.
+ iters: Number of times to repeat the timing.
+
+ Returns:
+ A float, representing the per-element wall time of the dataset in seconds.
+ This is the median time (with respect to `iters`) it takes for the dataset
+ to go through `num_elements` elements, divided by `num_elements.`
+ """
+ options = dataset_ops.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ dataset = dataset.with_options(options)
+ # NOTE: We use `dataset.skip()` to perform the iterations in C++, avoiding
+ # the overhead of multiple `session.run()` calls. Note that this relies on
+ # the underlying implementation of `skip`: if it is optimized in the future,
+ # we will have to change this code.
+ dataset = dataset.skip(num_elements - 1)
+ iterator = dataset_ops.make_initializable_iterator(dataset)
+ next_element = iterator.get_next()
+ next_element = nest.flatten(next_element)[0]
+
+ deltas = []
+ for _ in range(iters):
+ with session.Session() as sess:
+ # Run once to warm up the session caches.
+ sess.run(iterator.initializer)
+ sess.run(next_element)
+
+ sess.run(iterator.initializer)
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+ return np.median(deltas) / float(num_elements)
+
+ def run_and_report_benchmark(self,
+ dataset,
+ num_elements,
+ name,
+ iters=5,
+ extras=None):
+ # Measure the per-element wall time.
+ wall_time = self.run_benchmark(dataset, num_elements, iters)
+
+ if extras is None:
+ extras = {}
+ extras["elements_per_second"] = 1 / wall_time
+ extras["num_elements"] = num_elements
+ # 'mode' represents the mechanism used for iterating over dataset elements.
+ name = "%s_mode_cpp" % name
+ self.report_benchmark(
+ wall_time=wall_time, iters=iters, name=name, extras=extras)
diff --git a/tensorflow/python/data/benchmarks/filter_benchmark.py b/tensorflow/python/data/benchmarks/filter_benchmark.py
index e0ecf19..eb47b40 100644
--- a/tensorflow/python/data/benchmarks/filter_benchmark.py
+++ b/tensorflow/python/data/benchmarks/filter_benchmark.py
@@ -17,51 +17,26 @@
from __future__ import division
from __future__ import print_function
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.platform import test
# TODO(b/119837791): Add eager benchmarks.
-class FilterBenchmark(test.Benchmark):
+class FilterBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.Dataset.filter()`."""
def _benchmark(self, predicate, name):
- with ops.Graph().as_default():
- dataset = (
- dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_one_shot_iterator(dataset)
- next_element = iterator.get_next()
+ dataset = (
+ dataset_ops.Dataset.from_tensors(True).repeat(None).filter(predicate))
+ self.run_and_report_benchmark(dataset, num_elements=100000, name=name)
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- self.report_benchmark(iters=100, wall_time=median_wall_time, name=name)
-
- def benchmarkSimpleFunction(self):
+ def benchmark_simple_function(self):
self._benchmark(array_ops.identity, "simple_function")
- def benchmarkReturnComponentOptimization(self):
+ def benchmark_return_component_optimization(self):
self._benchmark(lambda x: x, "return_component")
if __name__ == "__main__":
- test.main()
+ benchmark_base.test.main()
diff --git a/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py b/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py
index 4e5559d..3af174a 100644
--- a/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py
+++ b/tensorflow/python/data/benchmarks/from_tensor_slices_benchmark.py
@@ -17,170 +17,70 @@
from __future__ import division
from __future__ import print_function
-import time
-
import numpy as np
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
# TODO(b/119837791): Add eager benchmarks.
-class FromTensorSlicesBenchmark(test.Benchmark):
+class FromTensorSlicesBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.Dataset.from_tensor_slices()`."""
- def benchmarkSliceRepeatBatch(self):
+ def benchmark_slice_repeat_batch(self):
input_size = 10000
batch_size = 100
num_epochs = 100
+ num_elements = input_size * num_epochs // batch_size
input_data = np.random.randn(input_size)
dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data)
- .repeat(num_epochs + 1).batch(batch_size))
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
+ dataset_ops.Dataset.from_tensor_slices(input_data).repeat(
+ num_epochs).batch(batch_size))
- with session.Session() as sess:
- sess.run(iterator.initializer)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- sess.run(next_element)
- deltas = []
- try:
- while True:
- start = time.time()
- sess.run(next_element)
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=num_elements,
name="slice_repeat_batch_input_%d_batch_%d" % (input_size, batch_size))
- def benchmarkSliceRepeatBatchCallable(self):
+ def benchmark_reshape_slice_repeat(self):
input_size = 10000
- batch_size = 100
+ reshape_dim = [100, 100]
num_epochs = 100
+ num_elements = num_epochs * reshape_dim[0]
+
input_data = np.random.randn(input_size)
dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data)
- .repeat(num_epochs + 1).batch(batch_size))
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
+ dataset_ops.Dataset.from_tensor_slices(
+ input_data.reshape(*reshape_dim)).repeat(num_epochs))
- with session.Session() as sess:
- sess.run(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=num_elements,
+ name="reshape_slice_repeat_input_%d" % input_size,
+ )
- median_wall_time = np.median(deltas)
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="slice_repeat_batch_callable_input_%d_batch_%d" %
- (input_size, batch_size))
-
- def benchmarkReshapeSliceRepeatCallable(self):
+ def benchmark_slice_batch_cache_repeat(self):
input_size = 10000
batch_size = 100
num_epochs = 100
+ num_elements = input_size * num_epochs // batch_size
input_data = np.random.randn(input_size)
dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data.reshape(100, 100))
- .repeat(num_epochs + 1))
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
+ dataset_ops.Dataset.from_tensor_slices(input_data).batch(
+ batch_size).cache().repeat(num_epochs))
- with session.Session() as sess:
- sess.run(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="reshape_slice_repeat_callable_input_%d_batch_%d" %
- (input_size, batch_size))
-
- def benchmarkSliceBatchCacheRepeatCallable(self):
- input_size = 10000
- batch_size = 100
- num_epochs = 100
-
- input_data = np.random.randn(input_size)
-
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(input_data).batch(batch_size)
- .cache().repeat(num_epochs + 1))
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- sess.run(iterator.initializer)
- get_next_element = sess.make_callable(next_element)
- # Run one whole epoch to burn in the computation.
- for _ in range(input_size // batch_size):
- get_next_element()
- deltas = []
- try:
- while True:
- start = time.time()
- get_next_element()
- deltas.append(time.time() - start)
- except errors.OutOfRangeError:
- pass
-
- median_wall_time = np.median(deltas)
- self.report_benchmark(
- iters=len(deltas),
- wall_time=median_wall_time,
- name="slice_batch_cache_repeat_callable_input_%d_batch_%d" %
- (input_size, batch_size))
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=num_elements,
+ name="slice_batch_cache_repeat_input_%d_batch_%d" % (input_size,
+ batch_size))
if __name__ == "__main__":
- test.main()
+ benchmark_base.test.main()
diff --git a/tensorflow/python/data/benchmarks/map_benchmark.py b/tensorflow/python/data/benchmarks/map_benchmark.py
index b620eaa..75b71ff 100644
--- a/tensorflow/python/data/benchmarks/map_benchmark.py
+++ b/tensorflow/python/data/benchmarks/map_benchmark.py
@@ -17,114 +17,51 @@
from __future__ import division
from __future__ import print_function
-import time
-
-import numpy as np
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.platform import test
# TODO(b/119837791): Add eager benchmarks.
-class MapBenchmark(test.Benchmark):
- """Bechmarks for `tf.data.Dataset.map()`."""
+class MapBenchmark(benchmark_base.DatasetBenchmarkBase):
+ """Benchmarks for `tf.data.Dataset.map()`."""
- def benchmarkChainOfMaps(self):
+ def benchmark_chain_of_maps(self):
+
+ def benchmark_helper(chain_length, map_fn, use_inter_op_parallelism, label):
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=10000,
+ name="chain_length_%d%s" % (chain_length, label))
+
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = True
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda x: x + 1
- use_inter_op_parallelism = False
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda x: x
- use_inter_op_parallelism = True # should not have any significance
- benchmark_label = "_short_circuit"
+ benchmark_helper(chain_length, lambda x: x + 1, True, "")
+ benchmark_helper(chain_length, lambda x: x + 1, False, "_single_threaded")
+ benchmark_helper(chain_length, lambda x: x, True, "_short_circuit")
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
- for _ in range(chain_length):
- dataset = dataset_ops.MapDataset(
- dataset,
- map_fn,
- use_inter_op_parallelism=use_inter_op_parallelism)
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_one_shot_iterator(dataset)
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element.op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element.op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="chain_length_%d%s" % (chain_length, benchmark_label))
-
- def benchmarkMapFanOut(self):
+ def benchmark_map_fan_out(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
+
+ def benchmark_helper(fan_out, map_fn, use_inter_op_parallelism, label):
+ dataset = dataset_ops.Dataset.from_tensors(
+ tuple(0 for _ in range(fan_out))).repeat(None)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=use_inter_op_parallelism)
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=10000,
+ name="fan_out_%d%s" % (fan_out, label))
+
for fan_out in fan_outs:
- for mode in ["general", "single-threaded", "short-circuit"]:
- if mode == "general":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = True
- benchmark_label = ""
- if mode == "single-threaded":
- map_fn = lambda *xs: [x + 1 for x in xs]
- use_inter_op_parallelism = False
- benchmark_label = "_single_threaded"
- if mode == "short-circuit":
- map_fn = lambda *xs: xs
- use_inter_op_parallelism = True # should not have any significance
- benchmark_label = "_short_circuit"
-
- with ops.Graph().as_default():
- dataset = dataset_ops.Dataset.from_tensors(
- tuple(0 for _ in range(fan_out))).repeat(None)
- dataset = dataset_ops.MapDataset(
- dataset,
- map_fn,
- use_inter_op_parallelism=use_inter_op_parallelism)
- options = dataset_ops.Options()
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_one_shot_iterator(dataset)
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- for _ in range(5):
- sess.run(next_element[0].op)
- deltas = []
- for _ in range(100):
- start = time.time()
- for _ in range(100):
- sess.run(next_element[0].op)
- end = time.time()
- deltas.append(end - start)
-
- median_wall_time = np.median(deltas) / 100
- self.report_benchmark(
- iters=1000,
- wall_time=median_wall_time,
- name="fan_out_%d%s" % (fan_out, benchmark_label))
+ benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], True, "")
+ benchmark_helper(fan_out, lambda *xs: [x + 1 for x in xs], False,
+ "_single_threaded")
+ benchmark_helper(fan_out, lambda *xs: xs, True, "_short_circuit")
if __name__ == "__main__":
- test.main()
+ benchmark_base.test.main()
diff --git a/tensorflow/python/data/benchmarks/range_benchmark.py b/tensorflow/python/data/benchmarks/range_benchmark.py
index 375ff33..80569e4 100644
--- a/tensorflow/python/data/benchmarks/range_benchmark.py
+++ b/tensorflow/python/data/benchmarks/range_benchmark.py
@@ -17,54 +17,26 @@
from __future__ import division
from __future__ import print_function
-import time
-
-from tensorflow.python.client import session
+from tensorflow.python.data.benchmarks import benchmark_base
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.platform import test
-
-_NUMPY_RANDOM_SEED = 42
-class RangeBenchmark(test.Benchmark):
+class RangeBenchmark(benchmark_base.DatasetBenchmarkBase):
"""Benchmarks for `tf.data.Dataset.range()`."""
- def _benchmarkRangeHelper(self, modeling_enabled):
- num_elements = 10000000 if modeling_enabled else 50000000
-
- # Use `Dataset.skip()` and `Dataset.take()` to perform the iteration in
- # C++, and focus on the minimal overheads (excluding Python invocation
- # costs).
- dataset = dataset_ops.Dataset.range(num_elements).skip(
- num_elements - 1).take(1)
- options = dataset_ops.Options()
- options.experimental_autotune = modeling_enabled
- options.experimental_optimization.apply_default_optimizations = False
- dataset = dataset.with_options(options)
- iterator = dataset_ops.make_initializable_iterator(dataset)
- next_element = iterator.get_next()
-
- with session.Session() as sess:
- # Run once to warm up the session caches.
- sess.run(iterator.initializer)
- sess.run(next_element)
-
- # Run once for timing.
- sess.run(iterator.initializer)
- start = time.time()
- sess.run(next_element)
- end = time.time()
-
- time_per_element = (end - start) / num_elements
- self.report_benchmark(
- iters=num_elements,
- wall_time=time_per_element,
- name="modeling_%s" % ("on" if modeling_enabled else "off"))
-
- def benchmarkRange(self):
+ def benchmark_range(self):
for modeling_enabled in [False, True]:
- self._benchmarkRangeHelper(modeling_enabled)
+ num_elements = 10000000 if modeling_enabled else 50000000
+ options = dataset_ops.Options()
+ options.experimental_autotune = modeling_enabled
+ dataset = dataset_ops.Dataset.range(num_elements)
+ dataset = dataset.with_options(options)
+
+ self.run_and_report_benchmark(
+ dataset,
+ num_elements=num_elements,
+ name="modeling_%s" % ("on" if modeling_enabled else "off"))
if __name__ == "__main__":
- test.main()
+ benchmark_base.test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
index 0481913..d0e5abc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -471,6 +471,20 @@
)
py_test(
+ name = "rebatch_dataset_test",
+ size = "small",
+ srcs = ["rebatch_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
name = "rejection_resample_test",
size = "medium",
srcs = ["rejection_resample_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py
index 4a8296d..993b511 100644
--- a/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/cardinality_test.py
@@ -49,8 +49,7 @@
lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
dataset_ops.Dataset.range(5)), cardinality.UNKNOWN),
("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
- concatenate(dataset_ops.Dataset.range(5)),
- cardinality.INFINITE),
+ concatenate(dataset_ops.Dataset.range(5)), cardinality.INFINITE),
("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
dataset_ops.Dataset.range(5).filter(lambda _: True)),
cardinality.UNKNOWN),
@@ -70,8 +69,7 @@
lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
- lambda _: dataset_ops.Dataset.from_tensors(0)),
- cardinality.UNKNOWN),
+ lambda _: dataset_ops.Dataset.from_tensors(0)), cardinality.UNKNOWN),
("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
cardinality.UNKNOWN),
("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
@@ -117,6 +115,13 @@
cardinality.INFINITE),
("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
5),
+ ("Shard1", lambda: dataset_ops.Dataset.range(5).shard(2, 0), 3),
+ ("Shard2", lambda: dataset_ops.Dataset.range(5).shard(8, 7), 0),
+ ("Shard3",
+ lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).shard(2, 0),
+ cardinality.UNKNOWN),
+ ("Shard4", lambda: dataset_ops.Dataset.range(5).repeat().shard(2, 0),
+ cardinality.INFINITE),
("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
("Skip3",
@@ -138,15 +143,13 @@
5),
("Zip2", lambda: dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
- ("Zip3", lambda: dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(5),
- dataset_ops.Dataset.range(3).repeat())), 5),
- ("Zip4", lambda: dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(5).repeat(),
- dataset_ops.Dataset.range(3).repeat())), cardinality.INFINITE),
- ("Zip5", lambda: dataset_ops.Dataset.zip(
- (dataset_ops.Dataset.range(5),
- dataset_ops.Dataset.range(3).filter(lambda _: True))),
+ ("Zip3", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
+ 5), dataset_ops.Dataset.range(3).repeat())), 5),
+ ("Zip4", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
+ 5).repeat(), dataset_ops.Dataset.range(3).repeat())),
+ cardinality.INFINITE),
+ ("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
+ 5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
cardinality.UNKNOWN),
# pylint: enable=g-long-lambda
)
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 19830a2..a48f080 100644
--- a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -28,6 +28,7 @@
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import data_flow_ops
@@ -35,7 +36,7 @@
from tensorflow.python.platform import test
-# TODO(b/117581999): add eager coverage.
+@test_util.run_v1_only("b/123903858: Add eager and V2 test coverage")
class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index ef4dbc8..d69043c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -26,6 +26,7 @@
from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -378,6 +379,7 @@
def map_fn(x):
# x has leading dimension 5, this will raise an error
return array_ops.gather(x, 10)
+
with self.assertRaisesRegexp(errors.InvalidArgumentError,
r"indices = 10 is not in \[0, 5\)"):
base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
@@ -478,8 +480,7 @@
self.assertDatasetsEqual(optimized, unoptimized)
- # TODO(b/117581999): Add eager coverage for the following tests.
- def testSkipEagerOptimizationIgnoreStateful(self):
+ def testOptimizationIgnoreStateful(self):
def map_fn(x):
with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
@@ -489,10 +490,13 @@
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
+ replacements = None
+ if not context.executing_eagerly():
+ # In graph mode, the ops have unique names.
+ replacements = [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)]
+ self.assertDatasetsRaiseSameError(unoptimized, optimized,
+ errors.InvalidArgumentError, replacements)
def testOptimizationIgnoreRagged(self):
# Make sure we ignore inputs that might not be uniformly sized
@@ -505,8 +509,7 @@
base_dataset, map_fn, expect_optimized=False)
self.assertDatasetsEqual(unoptimized, optimized)
- # TODO(b/117581999): Add eager coverage for the following tests.
- def testSkipEagerOptimizationIgnoreRaggedMap(self):
+ def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
def map_fn(x):
return array_ops.tile(x, x)
@@ -514,10 +517,29 @@
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self.assertDatasetsRaiseSameError(
- unoptimized, optimized, errors.InvalidArgumentError,
- [("OneShotIterator", "OneShotIterator_1", 1),
- ("IteratorGetNext", "IteratorGetNext_1", 1)])
+ replacements = None
+ if not context.executing_eagerly():
+ # In graph mode, the ops have unique names.
+ replacements = [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)]
+ self.assertDatasetsRaiseSameError(unoptimized, optimized,
+ errors.InvalidArgumentError, replacements)
+
+ def testOptimizationWithUnknownBatchShape(self):
+ tensor = sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 2]], values=[1, 2], dense_shape=[3, 4])
+
+ # Datasets with sparse tensors have unknown output shapes.
+ base_dataset = dataset_ops.Dataset.from_tensors(tensor)
+ unoptimized = base_dataset.apply(batching.map_and_batch(lambda x: x, 2))
+ options = dataset_ops.Options()
+ options.experimental_optimization.apply_default_optimizations = False
+ unoptimized = unoptimized.with_options(options)
+
+ options = dataset_ops.Options()
+ options.experimental_optimization.map_vectorization = True
+ optimized = unoptimized.with_options(options)
+ self.assertDatasetsEqual(unoptimized, optimized)
if __name__ == "__main__":
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
index bcd027e..a85e0cf 100644
--- a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_test.py
@@ -112,7 +112,7 @@
get_next = self.getNext(dataset)
self.evaluate(get_next())
- # TODO(b/123300735): Add eager coverage for the following tests.
+ @test_util.run_v1_only("b/123902160")
def testSkipEagerOptimizationLargeInputFromTensor(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
dataset = dataset_ops.Dataset.from_tensors(input_t)
@@ -127,7 +127,7 @@
sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
self.evaluate(get_next)
- # TODO(b/117581999): Add eager coverage for the following tests.
+ @test_util.run_v1_only("b/123902160")
def testSkipEagerOptimizationLargeInputFromTensorSlices(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
@@ -219,7 +219,7 @@
self.assertDatasetProduces(dataset, expected_output=[0])
@parameterized.named_parameters(_generate_captured_refvar_test_cases())
- # Skip eager because RefVariables are not supported in eager mode.
+ @test_util.run_v1_only("RefVariables are not supported in eager mode.")
def testSkipEagerOptimizationWithCapturedRefVar(self, dataset_fn):
"""Tests that default optimizations are disabled with ref variables."""
variable = variable_scope.get_variable(
@@ -236,7 +236,7 @@
options.experimental_optimization.noop_elimination = True
options.experimental_optimization.map_and_batch_fusion = True
optimized_dataset = unoptimized_dataset.with_options(options)
- optimized_it = optimized_dataset.make_initializable_iterator()
+ optimized_it = dataset_ops.make_initializable_iterator(optimized_dataset)
self.assertGreaterEqual(len(w), 1)
expected = ("tf.data static optimizations are not compatible with "
@@ -248,7 +248,8 @@
# Check that outputs are the same in the optimized and unoptimized cases,
# when the variable value is changing.
- unoptimized_it = unoptimized_dataset.make_initializable_iterator()
+ unoptimized_it = dataset_ops.make_initializable_iterator(
+ unoptimized_dataset)
with ops.control_dependencies([assign_op]):
unoptimized_output = unoptimized_it.get_next()
optimized_output = optimized_it.get_next()
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
new file mode 100644
index 0000000..0dcbd56
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -0,0 +1,60 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `_RebatchDataset` transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class RebatchDatasetTest(test_base.DatasetTestBase):
+
+ def testBasic(self):
+ dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
+ rebatched_dataset = batching._RebatchDataset(dataset, num_workers=4)
+ self.assertEqual(
+ [[32]], [ts.as_list() for ts in nest.flatten(dataset.output_shapes)])
+ self.assertEqual(
+ [[8]],
+ [ts.as_list() for ts in nest.flatten(rebatched_dataset.output_shapes)])
+
+ expected_output = [[k for k in range(i, i + 8)] for i in range(0, 1024, 8)] # pylint: disable=g-complex-comprehension
+ self.assertDatasetProduces(rebatched_dataset, expected_output)
+
+ def testScalarInputError(self):
+ dataset = dataset_ops.Dataset.range(1024)
+ with self.assertRaisesRegexp(ValueError, "at least one dimension"):
+ batching._RebatchDataset(dataset, num_workers=4)
+
+ def testUnknownBatchSizeError(self):
+ dataset = dataset_ops.Dataset.range(1024).batch(32)
+ with self.assertRaisesRegexp(ValueError, "unknown batch size datasets"):
+ batching._RebatchDataset(dataset, num_workers=4)
+
+ def testNotDivisibleError(self):
+ dataset = dataset_ops.Dataset.range(1024).batch(32, drop_remainder=True)
+ with self.assertRaisesRegexp(ValueError, "not divisible by"):
+ batching._RebatchDataset(dataset, num_workers=5)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index 4fd2a2e..caf571e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -409,6 +409,24 @@
)
py_test(
+ name = "rebatch_dataset_serialization_test",
+ size = "small",
+ srcs = ["rebatch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
name = "padded_batch_dataset_serialization_test",
size = "medium",
srcs = ["padded_batch_dataset_serialization_test.py"],
@@ -606,6 +624,24 @@
)
py_test(
+ name = "shard_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shard_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
name = "shuffle_and_repeat_dataset_serialization_test",
size = "medium",
srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
index 8cc66d0..84b8e5c 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
@@ -19,7 +19,6 @@
from __future__ import print_function
from tensorflow.python.data.experimental.ops import iterator_ops
-from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +34,8 @@
from tensorflow_estimator.python.estimator import model_fn
-class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
+@test_util.run_v1_only('b/123904664')
+class CheckpointInputPipelineHookTest(test.TestCase):
@staticmethod
def _model_fn(features, labels, mode, config):
@@ -69,7 +69,6 @@
def _build_iterator_saver_hook(self, est):
return iterator_ops.CheckpointInputPipelineHook(est)
- @test_util.run_deprecated_v1
def testReturnDatasetFromInputFn(self):
def _input_fn():
@@ -82,7 +81,6 @@
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
- @test_util.run_deprecated_v1
def testBuildIteratorInInputFn(self):
def _input_fn():
@@ -97,7 +95,6 @@
est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
- @test_util.run_deprecated_v1
def testDoNotRestore(self):
def _input_fn():
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
index bdbd870..ca45ecc 100644
--- a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -23,7 +23,6 @@
import numpy as np
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
-from tensorflow.python.data.experimental.ops.optimization_options import OptimizationOptions
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -78,7 +77,6 @@
# NOTE: We disable all default optimizations in serialization tests in order
# to test the actual dataset in question.
options = dataset_ops.Options()
- options.experimental_optimization = OptimizationOptions()
options.experimental_optimization.apply_default_optimizations = False
def ds_fn1_no_opt():
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py
new file mode 100644
index 0000000..b30db58
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/rebatch_dataset_serialization_test.py
@@ -0,0 +1,41 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the _RebatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class RebatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+
+ def build_dataset(num_elements, batch_size):
+ return batching._RebatchDataset(
+ dataset_ops.Dataset.range(num_elements).batch(
+ 4 * batch_size, drop_remainder=True),
+ num_workers=4)
+
+ self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py
new file mode 100644
index 0000000..99674b6
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shard_dataset_serialization_test.py
@@ -0,0 +1,42 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShardDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ShardDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
+
+ def _build_dataset(self, num_elements, num_shards, index):
+ return dataset_ops.Dataset.range(num_elements).shard(num_shards, index)
+
+ @parameterized.parameters((10, 5, 2, 3), (10, 10, 0, 9), (100, 2, 0, 1))
+ def testCore(self, elems, num_shards, index1, index2):
+ self.run_core_tests(lambda: self._build_dataset(elems, num_shards, index1),
+ lambda: self._build_dataset(elems, num_shards, index2),
+ elems // num_shards)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
index a4fe847..4733c2a 100644
--- a/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sleep_test.py
@@ -33,6 +33,7 @@
class SleepTest(test_base.DatasetTestBase):
def testSleep(self):
+ self.skipTest("b/123597912")
sleep_microseconds = 100
dataset = dataset_ops.Dataset.range(10).apply(
sleep.sleep(sleep_microseconds))
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index 8b33055..c53ac82 100644
--- a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -35,7 +35,6 @@
from tensorflow.python.platform import test
-@test_util.run_all_in_graph_and_eager_modes
def function_set_stats_aggregator(dataset,
aggregator,
prefix="",
@@ -53,20 +52,21 @@
return dataset.with_options(options)
+@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
("SetStatsAggregator", function_set_stats_aggregator),
("StatsOptions", function_apply_options),
)
class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
- def testBytesProduced(self, dataset_transformation):
+ @test_util.run_v1_only("b/123901126")
+ def testSkipEagerBytesProduced(self, dataset_transformation):
aggregator = stats_aggregator.StatsAggregator()
dataset = dataset_ops.Dataset.range(100).map(
lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
stats_ops.bytes_produced_stats("bytes_produced"))
dataset = dataset_transformation(dataset, aggregator)
next_element = self.getNext(dataset, requires_initialization=True)
- summary_t = aggregator.get_summary()
expected_sum = 0.0
for i in range(100):
@@ -78,8 +78,7 @@
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- # TODO(shivaniagrawal): ntentional breaking case
- summary_str = self.evaluate(summary_t)
+ summary_str = self.evaluate(aggregator.get_summary())
self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
@@ -357,13 +356,11 @@
100.0)
+@test_util.run_all_in_graph_and_eager_modes
@parameterized.named_parameters(
- dict(
- testcase_name="SetStatsAggregator",
- dataset_transformation=function_set_stats_aggregator),
- dict(
- testcase_name="StatsOptions",
- dataset_transformation=function_apply_options))
+ ("SetStatsAggregator", function_set_stats_aggregator),
+ ("StatsOptions", function_apply_options)
+)
class FeatureStatsDatasetTest(
stats_dataset_test_base.StatsDatasetTestBase,
reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
diff --git a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
index a8f5050..e6e7757 100644
--- a/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/wrap_unwrap_test.py
@@ -31,7 +31,7 @@
def testBasic(self):
ds = dataset_ops.Dataset.range(100)
- ds_variant = ds._as_variant_tensor() # pylint: disable=protected-access
+ ds_variant = ds._variant_tensor # pylint: disable=protected-access
wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)
unwrapped_variant = gen_dataset_ops.unwrap_dataset_variant(wrapped_variant)
@@ -42,10 +42,10 @@
for i in range(100):
self.assertEqual(i, self.evaluate(get_next()))
- # TODO(b/117581999): add eager coverage when supported.
+ @test_util.run_v1_only("b/123901304")
def testSkipEagerGPU(self):
ds = dataset_ops.Dataset.range(100)
- ds_variant = ds._as_variant_tensor() # pylint: disable=protected-access
+ ds_variant = ds._variant_tensor # pylint: disable=protected-access
wrapped_variant = gen_dataset_ops.wrap_dataset_variant(ds_variant)
with ops.device("/gpu:0"):
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
index f0cf7f0..39cb0a6 100644
--- a/tensorflow/python/data/experimental/ops/batching.py
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -645,3 +645,34 @@
num_parallel_calls, drop_remainder)
return _apply_fn
+
+
+class _RebatchDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that divides the batch size by `num_workers`."""
+
+ def __init__(self, input_dataset, num_workers):
+ self._input_dataset = input_dataset
+ output_shapes = input_dataset.output_shapes
+ if len(output_shapes) < 1:
+ raise ValueError("Input shape should have at least one dimension.")
+ if not output_shapes.dims[0].value:
+ raise ValueError("Cannot rebatch unknown batch size datasets.")
+ if output_shapes.dims[0].value % num_workers != 0:
+ raise ValueError(
+ "First dim of input shape: %d is not divisible by num_workers: %d" %
+ (output_shapes[0], num_workers))
+ output_dims = [d for d in output_shapes.dims]
+ output_dims[0] = output_dims[0] // num_workers
+ output_shapes = tensor_shape.TensorShapeV1(output_dims)
+ self._structure = structure.convert_legacy_structure(
+ self._input_dataset.output_types, output_shapes,
+ self._input_dataset.output_classes)
+ variant_tensor = ged_ops.experimental_rebatch_dataset(
+ self._input_dataset._variant_tensor, # pylint: disable=protected-access
+ num_workers=num_workers,
+ **dataset_ops.flat_structure(self))
+ super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)
+
+ @property
+ def _element_structure(self):
+ return self._structure
diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py
index 3a3040a..6676085 100644
--- a/tensorflow/python/data/experimental/ops/resampling.py
+++ b/tensorflow/python/data/experimental/ops/resampling.py
@@ -168,8 +168,7 @@
def _estimate_initial_dist_ds(
target_dist_t, class_values_ds, dist_estimation_batch_size=32,
smoothing_constant=10):
- num_classes = (target_dist_t.shape[0].value or
- array_ops.shape(target_dist_t)[0])
+ num_classes = (target_dist_t.shape[0] or array_ops.shape(target_dist_t)[0])
initial_examples_per_class_seen = array_ops.fill(
[num_classes], np.int64(smoothing_constant))
@@ -207,7 +206,7 @@
`[num_classes]`.
dist: The updated distribution. Type `float32`, shape `[num_classes]`.
"""
- num_classes = num_examples_per_class_seen.get_shape()[0].value
+ num_classes = num_examples_per_class_seen.get_shape()[0]
# Update the class-count based on what labels are seen in batch.
num_examples_per_class_seen = math_ops.add(
num_examples_per_class_seen, math_ops.reduce_sum(
diff --git a/tensorflow/python/data/experimental/ops/stats_aggregator.py b/tensorflow/python/data/experimental/ops/stats_aggregator.py
index 3e4c66b..0c6e686 100644
--- a/tensorflow/python/data/experimental/ops/stats_aggregator.py
+++ b/tensorflow/python/data/experimental/ops/stats_aggregator.py
@@ -44,7 +44,7 @@
dataset = ...
# Apply `StatsOptions` to associate `dataset` with `aggregator`.
- options = dataset_ops.Options()
+ options = tf.data.Options()
options.experimental_stats.aggregator = aggregator
dataset = dataset.with_options(options)
```
diff --git a/tensorflow/python/data/kernel_tests/dataset_test.py b/tensorflow/python/data/kernel_tests/dataset_test.py
index f319b24..1e764b3 100644
--- a/tensorflow/python/data/kernel_tests/dataset_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_test.py
@@ -92,15 +92,16 @@
("TFRecord", lambda: readers.TFRecordDataset(""), 1),
)
def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
- self.assertEqual(num_inputs, len(dataset_fn()._inputs()))
+ self.assertLen(dataset_fn()._inputs(), num_inputs)
+ @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
def testDatasetComplexSourceInputs(self):
dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
sparse_tensor.SparseTensor(
indices=np.array([[0, 0], [1, 0], [2, 0]]),
values=np.array([0, 0, 0]),
dense_shape=np.array([3, 1])))
- self.assertEqual(0, len(dataset_fn._inputs()))
+ self.assertEmpty(dataset_fn._inputs())
@parameterized.named_parameters(
("Batch",
@@ -266,27 +267,24 @@
round_trip_dataset, [self.evaluate(tf_value_fn())],
requires_initialization=True)
- # NOTE: This test is specific to graph mode and is skipped in eager mode.
- @test_util.run_deprecated_v1
+ @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorOneShot(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
dataset = dataset.batch(2)
- # NOTE: This test is specific to graph mode and is skipped in eager mode.
- @test_util.run_deprecated_v1
+ @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorOneShotSimple(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with test.mock.patch.object(logging, "warning") as mock_log:
- _ = dataset.make_one_shot_iterator()
+ _ = dataset_ops.make_one_shot_iterator(dataset)
self.assertRegexpMatches(
str(mock_log.call_args), "Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator.")
- # NOTE: This test is specific to graph mode and is skipped in eager mode.
- @test_util.run_deprecated_v1
+ @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorInitializable(self):
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py
index ff52821..69b5fd0 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_test.py
@@ -65,11 +65,11 @@
repeats = [[1, 2], [3, 4], [5, 0], [1, 7]]
components = np.array(repeats, dtype=np.int64)
iterator = (
- dataset_ops.Dataset.from_tensor_slices(components)
- .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x)
- .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y)
- .repeat(y))).make_initializable_iterator(
- shared_name="shared_flat_map_iterator"))
+ dataset_ops.make_initializable_iterator(
+ dataset_ops.Dataset.from_tensor_slices(components).flat_map(
+ lambda x: dataset_ops.Dataset.from_tensor_slices(x).flat_map(
+ lambda y: dataset_ops.Dataset.from_tensors(y).repeat(y))),
+ shared_name="shared_flat_map_iterator"))
init_op = iterator.initializer
get_next = iterator.get_next()
diff --git a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
index 546c2fb..2ce9c9a 100644
--- a/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
+++ b/tensorflow/python/data/kernel_tests/from_sparse_tensor_slices_test.py
@@ -29,10 +29,9 @@
from tensorflow.python.platform import test
-# NOTE: deprecated method in V2, no eager coverage added.
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class FromSparseTensorSlicesTest(test_base.DatasetTestBase):
- @test_util.run_deprecated_v1
def testFromSparseTensorSlices(self):
"""Test a dataset based on slices of a `tf.SparseTensor`."""
st = array_ops.sparse_placeholder(dtypes.float64)
diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py
index 8d98d65..7a17dd8 100644
--- a/tensorflow/python/data/kernel_tests/map_test.py
+++ b/tensorflow/python/data/kernel_tests/map_test.py
@@ -25,6 +25,7 @@
import numpy as np
from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.data.experimental.ops import threading_options
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
@@ -312,8 +313,8 @@
if context.executing_eagerly():
captured_iterator = iter(dataset_ops.Dataset.range(10))
else:
- captured_iterator = dataset_ops.Dataset.range(
- 10).make_initializable_iterator()
+ captured_iterator = dataset_ops.make_initializable_iterator(
+ dataset_ops.Dataset.range(10))
ds = _build_ds(captured_iterator)
return captured_iterator, ds
@@ -350,6 +351,7 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
+ @test_util.run_v1_only("b/123904513")
def testCaptureQueue(self):
elements = np.random.randint(100, size=[200])
queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])
@@ -391,6 +393,69 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
+ # TODO(b/121264236): add eager mode coverage when we have mutli-device setup.
+ @test_util.run_v1_only("b/121264236")
+ def testSkipEagerCaptureConstantsWithConflictingDevices(self):
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.cached_session(config=config):
+ with ops.device("/device:CPU:0"):
+ a = constant_op.constant(3.0)
+ with ops.device("/device:CPU:1"):
+ b = constant_op.constant(5.0)
+
+ def func(_):
+ return math_ops.add(a, b)
+
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(func)
+ expected_output = [8.0] * 10
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
+ # TODO(b/121264236): add eager mode coverage when we have mutli-device setup.
+ @test_util.run_v1_only(
+ "defun will convert RefVariables to ResourceVariables.")
+ def testSkipEagerRefVariablesWithConflictingDevices(self):
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.cached_session(config=config):
+
+ def func(_):
+ with ops.device("/device:CPU:0"):
+ a = variables.VariableV1(3.0)
+ with ops.device("/device:CPU:1"):
+ b = variables.VariableV1(5.0)
+ return math_ops.add(a, b)
+
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(func)
+ self.evaluate(variables.global_variables_initializer())
+ expected_output = [8.0] * 10
+ self.assertDatasetProduces(
+ dataset,
+ expected_output=expected_output,
+ requires_initialization=True)
+
+ # TODO(b/121264236): add eager mode coverage when we have mutli-device setup.
+ @test_util.run_v1_only("b/121264236")
+ def testSkipEagerResourceVariablesWithConflictingDevices(self):
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+ with self.cached_session(config=config):
+
+ def func(_):
+ with ops.device("/device:CPU:0"):
+ a = variables.Variable(3.0)
+ with ops.device("/device:CPU:1"):
+ b = variables.Variable(5.0)
+ return math_ops.add(a, b)
+
+ # The MapDataset node ends up with two ResourceVariable inputs, one on
+ # device CPU:0 and the other on device CPU:1. The placer cannot resolve
+ # this as it cannot place the MapDatasetOp on both devices.
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(10).map(func)
+ expected_error = (
+ errors.InvalidArgumentError,
+ "Cannot place the graph because a reference or resource edge "
+ "connects colocation groups with incompatible assigned devices")
+ self.assertDatasetProduces(
+ dataset, expected_error=expected_error, requires_initialization=True)
+
def testCaptureVariable(self):
counter_var = variable_scope.get_variable(
"counter", (), dtypes.int32, use_resource=True)
@@ -639,6 +704,13 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(get_next())
+ def testNestedListMapDataset(self):
+ dataset = dataset_ops.Dataset.from_tensors(
+ [0, 1, 2]).repeat(10).map(lambda a: ([a[1], a[0] + a[2]], a[1]))
+
+ expected_output = [(np.array([1, 2]), 1)] * 10
+ self.assertDatasetProduces(dataset, expected_output=expected_output)
+
def testPrefetch(self):
# We will use this event to test that `_map_py_func()` has been
# invoked a certain number of times (6 times, to be exact) after
@@ -746,6 +818,7 @@
dataset,
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
+ @test_util.run_v1_only("b/123904513")
def testParallelMapOutOfRangeError(self):
def raising_py_func(i):
if i == 100:
diff --git a/tensorflow/python/data/kernel_tests/optional_test.py b/tensorflow/python/data/kernel_tests/optional_test.py
index 2269bb8..4fde0aa 100644
--- a/tensorflow/python/data/kernel_tests/optional_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_test.py
@@ -329,7 +329,7 @@
if not works_on_gpu and test.is_gpu_available():
self.skipTest("Test case not yet supported on GPU.")
ds = dataset_ops.Dataset.from_tensors(np_value).repeat(3)
- iterator = ds.make_initializable_iterator()
+ iterator = dataset_ops.make_initializable_iterator(ds)
next_elem = iterator_ops.get_next_as_optional(iterator)
self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(
diff --git a/tensorflow/python/data/kernel_tests/shard_test.py b/tensorflow/python/data/kernel_tests/shard_test.py
index 9285506..9fc70ff 100644
--- a/tensorflow/python/data/kernel_tests/shard_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_test.py
@@ -19,11 +19,12 @@
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
-@test_util.run_all_in_graph_and_eager_modes
+@test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
class ShardTest(test_base.DatasetTestBase):
def testSimpleCase(self):
@@ -41,20 +42,24 @@
self.assertDatasetProduces(dataset, expected_output=[0, 5])
def testOffsetGreaterNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(5, 7)
+ with self.assertRaises(errors.InvalidArgumentError):
+ dataset = dataset_ops.Dataset.range(10).shard(5, 7)
+ self.evaluate(self.getNext(dataset)())
def testNegativeOffset(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(5, -3)
+ with self.assertRaises(errors.InvalidArgumentError):
+ dataset = dataset_ops.Dataset.range(10).shard(5, -3)
+ self.evaluate(self.getNext(dataset)())
def testNegativeNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(-3, 1)
+ with self.assertRaises(errors.InvalidArgumentError):
+ dataset = dataset_ops.Dataset.range(10).shard(-3, 1)
+ self.evaluate(self.getNext(dataset)())
def testZeroNumShards(self):
- with self.assertRaises(ValueError):
- dataset_ops.Dataset.range(10).shard(0, 1)
+ with self.assertRaises(errors.InvalidArgumentError):
+ dataset = dataset_ops.Dataset.range(10).shard(0, 1)
+ self.evaluate(self.getNext(dataset)())
def testIteratorEndsBeforeFirstElem(self):
dataset = dataset_ops.Dataset.range(1).shard(5, 2)
@@ -72,5 +77,10 @@
dataset = dataset_ops.Dataset.range(10).shard(4, 3)
self.assertDatasetProduces(dataset, expected_output=[3, 7])
+ def testNumShardsLargerThanDataset(self):
+ dataset = dataset_ops.Dataset.range(10).shard(20, 5)
+ self.assertDatasetProduces(dataset, expected_output=[5])
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index 7aa7f33..57df29e 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -19,6 +19,7 @@
import re
+from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
@@ -32,6 +33,13 @@
class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""
+ @classmethod
+ def setUpClass(cls):
+ if tf2.enabled():
+ dataset_ops.Dataset = dataset_ops.DatasetV2
+ else:
+ dataset_ops.Dataset = dataset_ops.DatasetV1
+
def assertSparseValuesEqual(self, a, b):
"""Asserts that two SparseTensors/SparseTensorValues are equal."""
self.assertAllEqual(a.indices, b.indices)
@@ -178,6 +186,8 @@
exception_class,
replacements=None):
"""Checks that datasets raise the same error on the first get_next call."""
+ if replacements is None:
+ replacements = []
next1 = self.getNext(dataset1)
next2 = self.getNext(dataset2)
try:
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 766c6d5..7cd838b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -28,7 +28,6 @@
from tensorflow.python.compat import compat
-from tensorflow.python.data.experimental.ops import filter_for_shard_ops
from tensorflow.python.data.experimental.ops import optimization_options
from tensorflow.python.data.experimental.ops import stats_options
from tensorflow.python.data.experimental.ops import threading_options
@@ -805,6 +804,59 @@
"""
return SkipDataset(self, count)
+ def shard(self, num_shards, index):
+ """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
+
+ This dataset operator is very useful when running distributed training, as
+ it allows each worker to read a unique subset.
+
+ When reading a single input file, you can skip elements as follows:
+
+ ```python
+ d = tf.data.TFRecordDataset(input_file)
+ d = d.shard(num_workers, worker_index)
+ d = d.repeat(num_epochs)
+ d = d.shuffle(shuffle_buffer_size)
+ d = d.map(parser_fn, num_parallel_calls=num_map_threads)
+ ```
+
+ Important caveats:
+
+ - Be sure to shard before you use any randomizing operator (such as
+ shuffle).
+ - Generally it is best if the shard operator is used early in the dataset
+ pipeline. For example, when reading from a set of TFRecord files, shard
+ before converting the dataset to input samples. This avoids reading every
+ file on every worker. The following is an example of an efficient
+ sharding strategy within a complete pipeline:
+
+ ```python
+ d = Dataset.list_files(pattern)
+ d = d.shard(num_workers, worker_index)
+ d = d.repeat(num_epochs)
+ d = d.shuffle(shuffle_buffer_size)
+ d = d.interleave(tf.data.TFRecordDataset,
+ cycle_length=num_readers, block_length=1)
+ d = d.map(parser_fn, num_parallel_calls=num_map_threads)
+ ```
+
+ Args:
+ num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ shards operating in parallel.
+ index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
+
+ Returns:
+ Dataset: A `Dataset`.
+
+ Raises:
+ InvalidArgumentError: if `num_shards` or `index` are illegal values.
+ Note: error checking is done on a best-effort basis, and errors aren't
+ guaranteed to be caught upon dataset creation. (e.g. providing in a
+ placeholder tensor bypasses the early checking, and will instead result
+ in an error during a session.run call.)
+ """
+ return ShardDataset(self, num_shards, index)
+
def batch(self, batch_size, drop_remainder=False):
"""Combines consecutive elements of this dataset into batches.
@@ -1100,6 +1152,18 @@
def filter(self, predicate):
"""Filters this dataset according to `predicate`.
+ ```python
+ d = tf.data.Dataset.from_tensor_slices([1, 2, 3])
+
+ d = d.filter(lambda x: x < 3) # [1, 2]
+
+ # `tf.math.equal(x, y)` is required for equality comparison
+ def filter_fn(x):
+ return tf.math.equal(x, 1)
+
+ d = d.filter(filter_fn) # [1]
+ ```
+
Args:
predicate: A function mapping a nested structure of tensors (having shapes
and types defined by `self.output_shapes` and `self.output_types`) to a
@@ -1550,60 +1614,9 @@
def skip(self, count):
return DatasetV1Adapter(super(DatasetV1, self).skip(count))
- @deprecation.deprecated(
- None, "Use `dataset.apply(tf.data.experimental.filter_for_shard(...))`.")
+ @functools.wraps(DatasetV2.shard)
def shard(self, num_shards, index):
- """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.
-
- This dataset operator is very useful when running distributed training, as
- it allows each worker to read a unique subset.
-
- When reading a single input file, you can skip elements as follows:
-
- ```python
- d = tf.data.TFRecordDataset(FLAGS.input_file)
- d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
- d = d.repeat(FLAGS.num_epochs)
- d = d.shuffle(FLAGS.shuffle_buffer_size)
- d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
- ```
-
- Important caveats:
-
- - Be sure to shard before you use any randomizing operator (such as
- shuffle).
- - Generally it is best if the shard operator is used early in the dataset
- pipeline. For example, when reading from a set of TFRecord files, shard
- before converting the dataset to input samples. This avoids reading every
- file on every worker. The following is an example of an efficient
- sharding strategy within a complete pipeline:
-
- ```python
- d = Dataset.list_files(FLAGS.pattern)
- d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
- d = d.repeat(FLAGS.num_epochs)
- d = d.shuffle(FLAGS.shuffle_buffer_size)
- d = d.interleave(tf.data.TFRecordDataset,
- cycle_length=FLAGS.num_readers, block_length=1)
- d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
- ```
-
- Args:
- num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
- shards operating in parallel.
- index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.
-
- Returns:
- Dataset: A `Dataset`.
-
- Raises:
- ValueError: if `num_shards` or `index` are illegal values. Note: error
- checking is done on a best-effort basis, and errors aren't guaranteed
- to be caught upon dataset creation. (e.g. providing in a placeholder
- tensor bypasses the early checking, and will instead result in an error
- during a session.run call.)
- """
- return self.apply(filter_for_shard_ops.filter_for_shard(num_shards, index))
+ return DatasetV1Adapter(super(DatasetV1, self).shard(num_shards, index))
@functools.wraps(DatasetV2.batch)
def batch(self, batch_size, drop_remainder=False):
@@ -2504,6 +2517,23 @@
super(SkipDataset, self).__init__(input_dataset, variant_tensor)
+class ShardDataset(UnaryUnchangedStructureDataset):
+ """A `Dataset` for sharding its input."""
+
+ def __init__(self, input_dataset, num_shards, index):
+ """See `Dataset.shard()` for details."""
+ self._input_dataset = input_dataset
+ self._num_shards = ops.convert_to_tensor(
+ num_shards, dtype=dtypes.int64, name="num_shards")
+ self._index = ops.convert_to_tensor(index, dtype=dtypes.int64, name="index")
+ variant_tensor = gen_dataset_ops.shard_dataset(
+ input_dataset._variant_tensor, # pylint: disable=protected-access
+ num_shards=self._num_shards,
+ index=self._index,
+ **flat_structure(self))
+ super(ShardDataset, self).__init__(input_dataset, variant_tensor)
+
+
class BatchDataset(UnaryDataset):
"""A `Dataset` that batches contiguous elements from its input."""
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 27a700f..dbe738b 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -47,7 +47,7 @@
":cli_test_utils",
":debug_py",
":grpc_debug_test_server",
- ":offline_analyzer",
+ ":offline_analyzer_lib",
":session_debug_testlib",
":source_remote",
] + if_not_windows([
@@ -393,6 +393,13 @@
name = "offline_analyzer",
srcs = ["cli/offline_analyzer.py"],
srcs_version = "PY2AND3",
+ deps = [":offline_analyzer_lib"],
+)
+
+py_library(
+ name = "offline_analyzer_lib",
+ srcs = ["cli/offline_analyzer.py"],
+ srcs_version = "PY2AND3",
deps = [
":analyzer_cli",
":debug_data",
@@ -404,12 +411,12 @@
py_library(
name = "debug_examples",
deps = [
- ":debug_errors",
- ":debug_fibonacci",
- ":debug_keras",
+ ":debug_errors_lib",
+ ":debug_fibonacci_lib",
+ ":debug_keras_lib",
] + if_not_v2([
- ":debug_mnist",
- ":debug_tflearn_iris",
+ ":debug_mnist_lib",
+ ":debug_tflearn_iris_lib",
]),
)
@@ -417,6 +424,13 @@
name = "debug_fibonacci",
srcs = ["examples/debug_fibonacci.py"],
srcs_version = "PY2AND3",
+ deps = [":debug_fibonacci_lib"],
+)
+
+py_library(
+ name = "debug_fibonacci_lib",
+ srcs = ["examples/debug_fibonacci.py"],
+ srcs_version = "PY2AND3",
deps = [
":debug_py",
"//tensorflow:tensorflow_py",
@@ -429,6 +443,13 @@
name = "debug_errors",
srcs = ["examples/debug_errors.py"],
srcs_version = "PY2AND3",
+ deps = [":debug_errors_lib"],
+)
+
+py_library(
+ name = "debug_errors_lib",
+ srcs = ["examples/debug_errors.py"],
+ srcs_version = "PY2AND3",
deps = [
":debug_py",
"//tensorflow:tensorflow_py",
@@ -440,6 +461,13 @@
name = "debug_mnist",
srcs = ["examples/debug_mnist.py"],
srcs_version = "PY2AND3",
+ deps = [":debug_mnist_lib"],
+)
+
+py_library(
+ name = "debug_mnist_lib",
+ srcs = ["examples/debug_mnist.py"],
+ srcs_version = "PY2AND3",
deps = [
":debug_py",
"//tensorflow:tensorflow_py",
@@ -451,6 +479,13 @@
name = "debug_tflearn_iris",
srcs = ["examples/debug_tflearn_iris.py"],
srcs_version = "PY2AND3",
+ deps = [":debug_tflearn_iris_lib"],
+)
+
+py_library(
+ name = "debug_tflearn_iris_lib",
+ srcs = ["examples/debug_tflearn_iris.py"],
+ srcs_version = "PY2AND3",
deps = [
":debug_py",
"//tensorflow:tensorflow_py",
@@ -462,6 +497,13 @@
name = "debug_keras",
srcs = ["examples/debug_keras.py"],
srcs_version = "PY2AND3",
+ deps = [":debug_keras_lib"],
+)
+
+py_library(
+ name = "debug_keras_lib",
+ srcs = ["examples/debug_keras.py"],
+ srcs_version = "PY2AND3",
deps = [
":debug_py",
"//tensorflow:tensorflow_py",
@@ -973,6 +1015,12 @@
"//tensorflow/python:training",
"//tensorflow/python:variables",
],
+ tags = [
+ "manual",
+ "no_pip",
+ "no_windows",
+ "notap",
+ ],
)
cuda_py_test(
diff --git a/tensorflow/python/debug/cli/stepper_cli.py b/tensorflow/python/debug/cli/stepper_cli.py
index 94eb275..fe1a012 100644
--- a/tensorflow/python/debug/cli/stepper_cli.py
+++ b/tensorflow/python/debug/cli/stepper_cli.py
@@ -251,6 +251,9 @@
lines.extend(
["Topologically-sorted transitive input(s) and fetch(es):", ""])
+ output = debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
+ self._add_deprecation_warning(output)
+
for i, element_name in enumerate(self._sorted_nodes):
if i < index_range[0] or i >= index_range[1]:
continue
@@ -269,15 +272,36 @@
override_names,
dirty_variable_names)
- lines.append(node_prefix + "] " + element_name)
-
- output = debugger_cli_common.rich_text_lines_from_rich_line_list(lines)
+ output.append_rich_line(node_prefix + "] " + element_name)
if verbose:
output.extend(self._node_status_label_legend())
return output
+ def _add_deprecation_warning(self, message):
+ """Add deprecation warning as RichTextLines."""
+ color = "yellow"
+ message.append_rich_line(
+ debugger_cli_common.RichLine(
+ "WARNING: the invoke_stepper feature of tfdbg has been deprecated ",
+ color))
+ message.append_rich_line(
+ debugger_cli_common.RichLine(
+ "and will be removed in the next release of TensorFlow.",
+ color))
+ message.append_rich_line(debugger_cli_common.RichLine("", color))
+ message.append_rich_line(
+ debugger_cli_common.RichLine(
+ "There now exist better alternatives of stepping debugging, "
+ "including:",
+ color))
+ message.append_rich_line(
+ debugger_cli_common.RichLine("- TensorBoard Debugger Plugin", color))
+ message.append_rich_line(
+ debugger_cli_common.RichLine("- Eager Execution", color))
+ message.append_rich_line(debugger_cli_common.RichLine("", color))
+
def _get_status_labels(self,
element_name,
handle_node_names,
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index 5cf69d0..c728373 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -235,6 +235,9 @@
], output.lines)
def testContToValidNodeShouldUpdateStatus(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with stepper.NodeStepper(self.sess, self.e) as node_stepper:
cli = stepper_cli.NodeStepperCLI(node_stepper)
@@ -275,6 +278,9 @@
self.assertIn(stepper_cli.NodeStepperCLI.STATE_CONT, stat_labels[index_d])
def testSteppingOneStepAtATimeShouldUpdateStatus(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with stepper.NodeStepper(self.sess, self.e) as node_stepper:
cli = stepper_cli.NodeStepperCLI(node_stepper)
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 9e78e20..bec858a 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -94,6 +94,9 @@
self.assertAllClose(6.0, stepper.cont("c"))
def testUsingNamesNotUsingIntermediateTensors(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with NodeStepper(self.sess, "e:0") as stepper:
# The first cont() call should have used no feeds.
result = stepper.cont("c:0")
@@ -119,6 +122,9 @@
}, stepper.last_feed_types())
def testUsingNodesNotUsingIntermediateTensors(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with NodeStepper(self.sess, self.e) as stepper:
# There should be no handles before any cont() calls.
self.assertEqual([], stepper.handle_names())
@@ -493,6 +499,9 @@
self.assertSetEqual({"ph0", "ph1"}, set(stepper.placeholders()))
def testContWithPlaceholders(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with NodeStepper(
self.sess,
self.y,
@@ -739,6 +748,9 @@
ops.reset_default_graph()
def testContToUpdateA(self):
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
+
with NodeStepper(self.sess, "optim") as stepper:
result = stepper.cont("a:0")
self.assertAllClose(1.0, result)
@@ -887,6 +899,8 @@
"clean" means no Variables have been updated by preceding cont() calls.
"""
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
with NodeStepper(self.sess, "optim") as stepper:
# First, call cont() on the two tensors on the intermediate level: e and
@@ -979,6 +993,8 @@
def testOverrideThenContToUpdateThenRemoveOverrideThenUpdateAgain(self):
"""Test cont() to update nodes after overriding tensor values."""
+ if test_util.is_gpu_available():
+ self.skipTest("b/123446705 this causes a segfault on GPU")
with NodeStepper(self.sess, "optim") as stepper:
result = stepper.cont("d:0")
diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD
index 77ec4a5..9bcdcce 100644
--- a/tensorflow/python/distribute/BUILD
+++ b/tensorflow/python/distribute/BUILD
@@ -138,7 +138,6 @@
"//tensorflow/python:variable_scope",
"//tensorflow/python/data",
"//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
- "//tensorflow/python/ops/losses",
"//tensorflow/tools/docs:doc_controls",
],
)
@@ -284,6 +283,28 @@
)
py_library(
+ name = "collective_all_reduce_strategy",
+ srcs = ["collective_all_reduce_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":mirrored_strategy",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:collective_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/distribute:cross_device_ops",
+ "//tensorflow/python/distribute:cross_device_utils",
+ "//tensorflow/python/distribute:input_lib",
+ "//tensorflow/python/distribute:multi_worker_util",
+ "//tensorflow/python/distribute:numpy_dataset",
+ "//tensorflow/python/distribute:values",
+ "//tensorflow/python/distribute/cluster_resolver:cluster_resolver_lib",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
name = "multi_worker_util",
srcs = [
"multi_worker_util.py",
diff --git a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
index b8d2ecc..abf628b 100644
--- a/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
+++ b/tensorflow/python/distribute/cluster_resolver/tpu_cluster_resolver.py
@@ -380,7 +380,7 @@
def get_job_name(self):
if (self._shouldResolve() or
- self._tpu.startswith(compat.as_bytes('grpc://'))):
+ self._isRunningInGCE()):
return self.task_type
def cluster_spec(self):
diff --git a/tensorflow/python/distribute/collective_all_reduce_strategy.py b/tensorflow/python/distribute/collective_all_reduce_strategy.py
new file mode 100644
index 0000000..74d3030
--- /dev/null
+++ b/tensorflow/python/distribute/collective_all_reduce_strategy.py
@@ -0,0 +1,387 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class CollectiveAllReduceStrategy implementing DistributionStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+from tensorflow.core.protobuf import rewriter_config_pb2
+from tensorflow.python.distribute import cross_device_ops as cross_device_ops_lib
+from tensorflow.python.distribute import cross_device_utils
+from tensorflow.python.distribute import device_util
+from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import input_lib
+from tensorflow.python.distribute import mirrored_strategy
+from tensorflow.python.distribute import multi_worker_util
+from tensorflow.python.distribute import numpy_dataset
+from tensorflow.python.distribute import values
+from tensorflow.python.distribute.cluster_resolver import SimpleClusterResolver
+from tensorflow.python.distribute.cluster_resolver import TFConfigClusterResolver
+from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import collective_ops
+from tensorflow.python.platform import tf_logging as logging
+
+
+# TODO(yuefengz): support in-graph replication.
+class CollectiveAllReduceStrategy(distribute_lib.DistributionStrategy):
+ """Distribution strategy that uses collective ops for all-reduce.
+
+ It is similar to MirroredStrategy but it uses collective ops for reduction.
+
+ By default it uses all local GPUs or CPU for single-worker training.
+
+ When 'TF_CONFIG' environment variable is given, it parses cluster_spec,
+ task_type and task_id from 'TF_CONFIG' and turns into a multi-worker strategy
+ which mirrores models on GPUs of all machines in a cluster. In the current
+ implementation, it uses all GPUs in a cluster and it assumes all workers have
+ the same number of GPUs.
+ """
+
+ def __init__(self):
+ """Initializes the object."""
+ super(CollectiveAllReduceStrategy, self).__init__(
+ CollectiveAllReduceExtended(self))
+
+
+class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended):
+ """Implementation of CollectiveAllReduceStrategy."""
+
+ def __init__(self,
+ container_strategy,
+ cluster_resolver=TFConfigClusterResolver()):
+ distribute_lib.DistributionStrategyExtended.__init__(
+ self, container_strategy)
+ self._cross_device_ops = None
+ self._initialize_strategy(cluster_resolver)
+ assert isinstance(self._get_cross_device_ops(),
+ cross_device_ops_lib.CollectiveAllReduce)
+
+ def _initialize_strategy(self, cluster_resolver):
+ if cluster_resolver.cluster_spec().as_dict():
+ self._initialize_multi_worker(cluster_resolver)
+ else:
+ self._initialize_local(cluster_resolver)
+ # Save the num_gpus_per_worker for configure method.
+ self._num_gpus_per_worker = cluster_resolver.num_accelerators()
+
+ def _initialize_local(self, cluster_resolver):
+ """Initializes the object for local training."""
+ self._is_chief = True
+ self._num_workers = 1
+
+ num_gpus = cluster_resolver.num_accelerators()
+ if num_gpus:
+ local_devices = tuple("/device:GPU:%d" % i for i in range(num_gpus))
+ else:
+ local_devices = ("/device:CPU:0",)
+ self._worker_device = device_util.canonicalize("/device:CPU:0")
+ self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
+
+ self._collective_keys = cross_device_utils.CollectiveKeys()
+ super(CollectiveAllReduceExtended, self)._initialize_local(local_devices)
+ # TODO(yuefengz): remove num_gpus_per_worker from CollectiveAllReduce.
+ self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
+ num_workers=self._num_workers,
+ num_gpus_per_worker=num_gpus,
+ collective_keys=self._collective_keys)
+
+ self._cluster_spec = None
+ self._task_type = None
+ self._task_id = None
+
+ logging.info("CollectiveAllReduceStrategy with local_devices = %r",
+ local_devices)
+
+ def _initialize_multi_worker(self, cluster_resolver):
+ """Initializes the object for multi-worker training."""
+ # TODO(yuefengz): The `num_gpus` is only for this particular task. It
+ # assumes all workers have the same number of GPUs. We should remove this
+ # assumption by querying all tasks for their numbers of GPUs.
+ num_gpus = cluster_resolver.num_accelerators()
+ cluster_spec = multi_worker_util.normalize_cluster_spec(
+ cluster_resolver.cluster_spec())
+ task_type = cluster_resolver.task_type
+ task_id = cluster_resolver.task_id
+ if task_type is None or task_id is None:
+ raise ValueError("When `cluster_spec` is given, you must also specify "
+ "`task_type` and `task_id` in the `cluster_resolver`.")
+ if task_type not in ("chief", "worker"):
+ raise ValueError(
+ "Unrecognized task_type: %r, valid task types are: \"chief\", "
+ "\"worker\"." % task_type)
+
+ self._num_workers = multi_worker_util.worker_count(cluster_spec, task_type)
+ if not self._num_workers:
+ raise ValueError("No `worker` or `chief` tasks can be found in "
+ "`cluster_spec`.")
+
+ self._is_chief = multi_worker_util.is_chief(cluster_spec, task_type,
+ task_id)
+
+ self._worker_device = "/job:%s/task:%d" % (task_type, task_id)
+ self._host_input_device = numpy_dataset.SingleDevice(self._worker_device)
+ if num_gpus:
+ local_devices = tuple("%s/device:GPU:%d" % (self._worker_device, i)
+ for i in range(num_gpus))
+ else:
+ local_devices = (self._worker_device,)
+
+ self._collective_keys = cross_device_utils.CollectiveKeys()
+ super(CollectiveAllReduceExtended, self)._initialize_local(local_devices)
+ self._input_workers = input_lib.InputWorkers(
+ self._device_map, [(self._worker_device, self.worker_devices)])
+ self._cross_device_ops = cross_device_ops_lib.CollectiveAllReduce(
+ num_workers=self._num_workers,
+ num_gpus_per_worker=num_gpus,
+ collective_keys=self._collective_keys)
+
+ # Add a default device so that ops without specified devices will not end up
+ # on other workers.
+ self._default_device = "/job:%s/task:%d" % (task_type, task_id)
+
+ self._cluster_spec = cluster_spec
+ self._task_type = task_type
+ self._task_id = task_id
+
+ logging.info(
+ "Multi-worker CollectiveAllReduceStrategy with "
+ "cluster_spec = %r, task_type = %r, task_id = %r, "
+ "num_workers = %r, local_devices = %r", cluster_spec.as_dict(),
+ task_type, task_id, self._num_workers, local_devices)
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ colocate_with = kwargs.pop("colocate_with", None)
+ if colocate_with is None:
+ device_map = self._device_map
+ logical_device = 0 # TODO(josh11b): Get logical device from scope here.
+ elif isinstance(colocate_with, numpy_dataset.SingleDevice):
+ with ops.device(colocate_with.device):
+ return next_creator(*args, **kwargs)
+ else:
+ device_map = colocate_with.device_map
+ logical_device = colocate_with.logical_device
+
+ def _real_mirrored_creator(devices, *args, **kwargs):
+ """Creates one MirroredVariable on the current worker."""
+ unique_var_name = ops.get_default_graph().unique_name(
+ kwargs["name"], mark_as_used=False).rstrip("/")
+ # pylint: disable=protected-access
+ collective_instance_key = self._collective_keys.get_instance_key(
+ key_id=unique_var_name)
+ # Only the first device participles in the broadcast of initial values.
+ group_key = self._collective_keys.get_group_key([devices[0]])
+ group_size = self._num_workers
+ if "initial_value" not in kwargs:
+ raise ValueError("Initial value must be specified.")
+ initial_value = kwargs["initial_value"]
+ if callable(initial_value):
+ initial_value_fn = initial_value
+ else:
+ initial_value_fn = lambda: initial_value
+
+ value_list = []
+ for i, d in enumerate(devices):
+ with ops.init_scope(), ops.device(d):
+ if i == 0:
+ # The initial value fn makes sure variables all initialized to
+ # same values. The first device of the chief worker will send their
+ # variable values to other workers.
+ def _overridden_initial_value_fn(device=d, index=i): # pylint: disable=g-missing-docstring
+ with ops.device(device):
+ initial_value = initial_value_fn()
+ assert not callable(initial_value)
+ initial_value = ops.convert_to_tensor(initial_value)
+
+ assert index == 0, index
+ if self._num_workers > 1:
+ if self._is_chief:
+ bcast_send = collective_ops.broadcast_send(
+ initial_value, initial_value.shape, initial_value.dtype,
+ group_size, group_key, collective_instance_key)
+ with ops.control_dependencies([bcast_send]):
+ return array_ops.identity(initial_value)
+ else:
+ return collective_ops.broadcast_recv(
+ initial_value.shape, initial_value.dtype, group_size,
+ group_key, collective_instance_key)
+ return initial_value
+ else:
+ # Give replicas meaningful distinct names:
+ var0name = value_list[0].name.split(":")[0]
+ # We append a / to variable names created on replicas with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+
+ # Variables on non-first replica get initial values from the
+ # variables created on the first device of each worker.
+ def _overridden_initial_value_fn(device=d, index=i):
+ assert index > 0
+ with ops.device(device):
+ if context.executing_eagerly():
+ return array_ops.identity(value_list[0].value())
+ else:
+ return array_ops.identity(value_list[0].initial_value)
+
+ kwargs["initial_value"] = _overridden_initial_value_fn
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ # Don't record operations (e.g. other variable reads) during
+ # variable creation.
+ with tape.stop_recording():
+ v = next_creator(*args, **kwargs)
+
+ if i == 0:
+ actual_var_name = v.name.split(":")[0]
+ assert unique_var_name == actual_var_name, "%r vs %r" % (
+ unique_var_name, actual_var_name)
+ assert not isinstance(v, values.DistributedVariable)
+ value_list.append(v)
+ return value_list
+
+ # pylint: disable=protected-access
+ return mirrored_strategy._create_mirrored_variable(
+ self._container_strategy(), device_map, logical_device,
+ _real_mirrored_creator, *args, **kwargs)
+
+ def _make_dataset_iterator(self, dataset):
+ return input_lib.DatasetIterator(dataset, self._input_workers,
+ self._num_replicas_in_sync)
+
+ def _make_input_fn_iterator(
+ self,
+ input_fn,
+ replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
+ """Distributes the dataset to each local GPU."""
+ if self._cluster_spec is None:
+ input_pipeline_id = 0
+ else:
+ input_pipeline_id = multi_worker_util.id_in_cluster(
+ self._cluster_spec, self._task_type, self._task_id)
+ input_context = distribute_lib.InputContext(
+ num_input_pipelines=self._num_workers,
+ input_pipeline_id=input_pipeline_id,
+ num_replicas_in_sync=self._num_replicas_in_sync)
+
+ return input_lib.InputFunctionIterator(
+ input_fn, self._input_workers, [input_context])
+
+ def _configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the object.
+
+ Args:
+ session_config: a `tf.ConfigProto`
+ cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
+ cluster configurations.
+ task_type: the current task type, such as "worker".
+ task_id: the current task id.
+
+ Raises:
+ ValueError: if `task_type` is not in the `cluster_spec`.
+ """
+ if cluster_spec:
+ # Use the num_gpus_per_worker recorded in constructor since _configure
+ # doesn't take num_gpus.
+ cluster_resolver = SimpleClusterResolver(
+ cluster_spec=multi_worker_util.normalize_cluster_spec(cluster_spec),
+ task_type=task_type,
+ task_id=task_id,
+ num_accelerators=self._num_gpus_per_worker)
+ self._initialize_multi_worker(cluster_resolver)
+ assert isinstance(self._get_cross_device_ops(),
+ cross_device_ops_lib.CollectiveAllReduce)
+
+ if session_config:
+ session_config.CopyFrom(self._update_config_proto(session_config))
+
+ def _update_config_proto(self, config_proto):
+ updated_config = copy.deepcopy(config_proto)
+ # Enable the scoped allocator optimization for CollectiveOps. This
+ # optimization converts many small all-reduces into fewer larger
+ # all-reduces.
+ rewrite_options = updated_config.graph_options.rewrite_options
+ rewrite_options.scoped_allocator_optimization = (
+ rewriter_config_pb2.RewriterConfig.ON)
+ # We turn on ScopedAllocator only for CollectiveReduce op, i.e. enable_op =
+ # ["CollectiveReduce"]. Since we can't assign to a repeated proto field, we
+ # clear and then append.
+ del rewrite_options.scoped_allocator_opts.enable_op[:]
+ rewrite_options.scoped_allocator_opts.enable_op.append("CollectiveReduce")
+
+ if not self._cluster_spec:
+ return updated_config
+
+ assert self._task_type
+ assert self._task_id is not None
+
+ # Collective group leader is needed for collective ops to coordinate
+ # workers.
+ if "chief" in self._cluster_spec.jobs:
+ updated_config.experimental.collective_group_leader = (
+ "/job:chief/replica:0/task:0")
+ else:
+ if "worker" not in self._cluster_spec.jobs:
+ raise ValueError(
+ "You must have `chief` or `worker` jobs in the `cluster_spec`.")
+ updated_config.experimental.collective_group_leader = (
+ "/job:worker/replica:0/task:0")
+
+ # The device filters prevent communication between workers.
+ del updated_config.device_filters[:]
+ updated_config.device_filters.append(
+ "/job:%s/task:%d" % (self._task_type, self._task_id))
+
+ return updated_config
+
+ @property
+ def experimental_between_graph(self):
+ return True
+
+ @property
+ def experimental_should_init(self):
+ return True
+
+ @property
+ def should_checkpoint(self):
+ return self._is_chief
+
+ @property
+ def should_save_summary(self):
+ return self._is_chief
+
+ @property
+ def _num_replicas_in_sync(self):
+ return len(self.worker_devices) * self._num_workers
+
+ # TODO(priyag): Delete this once all strategies use global batch size.
+ @property
+ def _global_batch_size(self):
+ """`make_dataset_iterator` and `make_numpy_iterator` use global batch size.
+
+ `make_input_fn_iterator` assumes per-replica batching.
+
+ Returns:
+ Boolean.
+ """
+ return True
diff --git a/tensorflow/python/distribute/cross_device_utils.py b/tensorflow/python/distribute/cross_device_utils.py
index e8066dd..cb5417e 100644
--- a/tensorflow/python/distribute/cross_device_utils.py
+++ b/tensorflow/python/distribute/cross_device_utils.py
@@ -30,7 +30,7 @@
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nccl_ops
@@ -645,14 +645,14 @@
def aggregate_tensors_or_indexed_slices(values, accumulation_fn=math_ops.add_n):
"""Aggregate tensors using `accumulation_fn` and IndexedSlices via concat."""
if any(isinstance(v, ops.IndexedSlices) for v in values):
- return gradients_impl._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access
+ return gradients_util._AggregateIndexedSlicesGradients(values) # pylint: disable=protected-access
else:
return accumulation_fn(values)
def divide_by_n_tensors_or_indexed_slices(value, n):
if isinstance(value, ops.IndexedSlices):
- value = gradients_impl._HandleNestedIndexedSlices(value) # pylint: disable=protected-access
+ value = gradients_util._HandleNestedIndexedSlices(value) # pylint: disable=protected-access
return ops.IndexedSlices(
value.values / n, value.indices, value.dense_shape)
else:
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index f1d7d1d..eb3fd1d 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -34,6 +34,9 @@
from tensorflow.python.training import server_lib
+_thread_local = threading.local()
+
+
class _TaskType(object):
PS = "ps"
WORKER = "worker"
@@ -383,6 +386,27 @@
rpc_layer=None,
environment=None):
"""Runs a standard server."""
+ # Check if the Server is already running. If so, assert that no configuration
+ # options have changed, and return the existing Server. This allows us to
+ # call `run_distribute_coordinator` multiple times.
+ if getattr(_thread_local, "server", None) is not None:
+ assert _thread_local.cluster_spec == cluster_spec
+ assert _thread_local.task_type == task_type
+ assert _thread_local.task_id == task_id
+ assert _thread_local.session_config_str == repr(session_config)
+ assert _thread_local.rpc_layer == rpc_layer
+ assert _thread_local.environment == environment
+ return _thread_local.server
+ else:
+ # This method is not thread-safe.
+ _thread_local.server_started = True
+ _thread_local.cluster_spec = cluster_spec
+ _thread_local.task_type = task_type
+ _thread_local.task_id = task_id
+ _thread_local.session_config_str = repr(session_config)
+ _thread_local.rpc_layer = rpc_layer
+ _thread_local.environment = environment
+
assert cluster_spec
target = cluster_spec.task_address(task_type, task_id)
if rpc_layer:
@@ -404,8 +428,6 @@
if environment == "google":
server = _FakeServer()
- server.start()
- return server
else:
if session_config:
logging.info(
@@ -420,8 +442,10 @@
task_index=task_id,
config=session_config,
protocol=rpc_layer)
- server.start()
- return server
+
+ server.start()
+ _thread_local.server = server
+ return server
def _run_between_graph_client(worker_fn, strategy, eval_fn, eval_strategy,
@@ -648,7 +672,7 @@
for a task. The distribute coordinator will make a copy of the `strategy`
object, call its `configure` method and pass it to `worker_fn` as an argument.
- The `worker_fn` defines the training logic and is called under a its own
+ The `worker_fn` defines the training logic and is called under its own
worker context which can be accessed to via `get_current_worker_context`. A
worker context provides access to configurations for each task, e.g. the
task_type, task_id, master target and so on. Since `worker_fn` will be called
@@ -674,7 +698,7 @@
the worker context.
The `cluster_spec` can be either passed by the argument or parsed from the
- "TF_CONFIG" envrionment variable. Example of a TF_CONFIG:
+ "TF_CONFIG" environment variable. Example of a TF_CONFIG:
```
cluster = {'chief': ['host0:2222'],
'ps': ['host1:2222', 'host2:2222'],
@@ -689,19 +713,19 @@
will be created to call `eval_fn` with its `task_type` set to "evaluator". If
`eval_fn` is not defined, fall back to `worker_fn`. This implies that
evaluation will be done on a single machine if there is an "evaluator" task.
- If "evaluator" doesn't exit in the cluster_spec, it entirely depends on the
+ If "evaluator" doesn't exist in the cluster_spec, it entirely depends on the
`worker_fn` for how to do evaluation.
Args:
worker_fn: the function to be called. The function should accept a
`strategy` object and will be given access to a context object via a
context manager scope.
- strategy: a DistributionStrategy object which specifying whether it should
+ strategy: a DistributionStrategy object specifying whether it should
run between-graph replicated training or not, whether to run init ops,
etc. This object will also be configured given `session_config`,
`cluster_spec`, `task_type` and `task_id`.
eval_fn: optional function for "evaluator" task. If `eval_fn` is not passed
- in but a "evaluator" task found in the `cluster_spec`, the `worker_fn`
+ in but a "evaluator" task is found in the `cluster_spec`, the `worker_fn`
will be used for this task.
eval_strategy: optional DistributionStrategy object for "evaluator" task.
mode: in which mode this distribute coordinator runs.
@@ -719,7 +743,8 @@
Returns:
In the client job, return the value returned by `worker_fn` if
- it is in-graph replication; return None otherwise.
+ it is in-graph replication or INDEPENDENT_WORKER mode; return None
+ otherwise.
"""
tf_config = json.loads(os.environ.get("TF_CONFIG", "{}"))
if not cluster_spec:
@@ -736,7 +761,7 @@
rpc_layer = tf_config.get("rpc_layer", rpc_layer)
environment = tf_config.get("environment", None)
- # Setting the session config is necessary for some strategies such
+ # Setting the session config is necessary for some strategies such as
# CollectiveAllReduceStrategy.
session_config = session_config or config_pb2.ConfigProto(
allow_soft_placement=True)
@@ -813,23 +838,22 @@
session_config=session_config,
rpc_layer=rpc_layer,
environment=environment)
-
if task_type in [_TaskType.CHIEF, _TaskType.WORKER]:
if strategy.extended.experimental_between_graph:
# All jobs run `worker_fn` if between-graph.
- _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
- task_id, session_config, rpc_layer)
+ return _run_single_worker(worker_fn, strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
# Only one node runs `worker_fn` if in-graph.
context = _WorkerContext(strategy, cluster_spec, task_type, task_id)
if context.is_chief:
- _run_single_worker(worker_fn, strategy, cluster_spec, None, None,
- session_config, rpc_layer)
+ return _run_single_worker(worker_fn, strategy, cluster_spec, None,
+ None, session_config, rpc_layer)
else:
server.join()
elif task_type == _TaskType.EVALUATOR:
- _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
- task_id, session_config, rpc_layer)
+ return _run_single_worker(eval_fn, eval_strategy, cluster_spec, task_type,
+ task_id, session_config, rpc_layer)
else:
if task_type != _TaskType.PS:
raise ValueError("Unexpected task_type: %r" % task_type)
diff --git a/tensorflow/python/distribute/distribute_coordinator_test.py b/tensorflow/python/distribute/distribute_coordinator_test.py
index ceb4483..2299716 100644
--- a/tensorflow/python/distribute/distribute_coordinator_test.py
+++ b/tensorflow/python/distribute/distribute_coordinator_test.py
@@ -864,6 +864,9 @@
cluster_spec = {"worker": ["localhost:0"]}
tf_config = {"cluster": cluster_spec}
+ # Reset the saved Server state.
+ distribute_coordinator._thread_local = threading.local() # pylint: disable=protected-access
+
with test.mock.patch.dict("os.environ",
{"TF_CONFIG": json.dumps(tf_config)}):
distribute_coordinator.run_distribute_coordinator(
diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py
index 2cc99b3..3e48364 100644
--- a/tensorflow/python/distribute/distribute_lib.py
+++ b/tensorflow/python/distribute/distribute_lib.py
@@ -37,7 +37,6 @@
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
@@ -79,14 +78,14 @@
# Public utility functions.
-@tf_export("distribute.get_loss_reduction")
+@tf_export(v1=["distribute.get_loss_reduction"])
def get_loss_reduction():
- """`tf.distribute.ReduceOp` corresponding to the last loss reduction."""
- loss_reduction = ops.get_default_graph()._last_loss_reduction # pylint: disable=protected-access
- if (loss_reduction == losses_impl.Reduction.SUM or
- loss_reduction == losses_impl.ReductionV2.SUM):
- return reduce_util.ReduceOp.SUM
- return reduce_util.ReduceOp.MEAN
+ """DEPRECATED: Now always returns `tf.distribute.ReduceOp.SUM`.
+
+ We now always make the complete adjustment when computing the loss, so
+ code should always add gradients/losses across replicas, never average.
+ """
+ return reduce_util.ReduceOp.SUM
# ------------------------------------------------------------------------------
diff --git a/tensorflow/python/distribute/distribute_lib_test.py b/tensorflow/python/distribute/distribute_lib_test.py
index c147849..6876af3 100644
--- a/tensorflow/python/distribute/distribute_lib_test.py
+++ b/tensorflow/python/distribute/distribute_lib_test.py
@@ -19,7 +19,7 @@
from __future__ import print_function
from tensorflow.python.distribute import distribute_lib
-from tensorflow.python.distribute import distribution_strategy_context
+from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import variable_scope
@@ -60,13 +60,12 @@
def _assert_in_default_state(t):
- t.assertIs(distribution_strategy_context._get_default_replica_context(),
- distribution_strategy_context.get_replica_context())
- t.assertIs(None, distribution_strategy_context.get_cross_replica_context())
- t.assertFalse(distribution_strategy_context.in_cross_replica_context())
- t.assertIs(distribution_strategy_context._get_default_strategy(),
- distribution_strategy_context.get_strategy())
- t.assertFalse(distribution_strategy_context.has_strategy())
+ t.assertIs(ds_context._get_default_replica_context(),
+ ds_context.get_replica_context())
+ t.assertIs(None, ds_context.get_cross_replica_context())
+ t.assertFalse(ds_context.in_cross_replica_context())
+ t.assertIs(ds_context._get_default_strategy(), ds_context.get_strategy())
+ t.assertFalse(ds_context.has_strategy())
class TestStrategyTest(test.TestCase):
@@ -76,14 +75,12 @@
dist = _TestStrategy()
def run_fn():
- replica_context = distribution_strategy_context.get_replica_context()
+ replica_context = ds_context.get_replica_context()
self.assertTrue(replica_context is not None)
- self.assertIs(None,
- distribution_strategy_context.get_cross_replica_context())
- self.assertFalse(distribution_strategy_context.in_cross_replica_context())
- self.assertTrue(distribution_strategy_context.has_strategy())
- self.assertIs(dist,
- distribution_strategy_context.get_strategy())
+ self.assertIs(None, ds_context.get_cross_replica_context())
+ self.assertFalse(ds_context.in_cross_replica_context())
+ self.assertTrue(ds_context.has_strategy())
+ self.assertIs(dist, ds_context.get_strategy())
self.assertEqual("foo", replica_context.merge_call(None, test_arg="foo"))
expected_value = _get_test_variable(
"bar", variable_scope.VariableSynchronization.AUTO,
@@ -101,13 +98,11 @@
_assert_in_default_state(self)
dist = _TestStrategy()
with dist.scope():
- self.assertIs(None, distribution_strategy_context.get_replica_context())
- self.assertIs(dist,
- distribution_strategy_context.get_cross_replica_context())
- self.assertTrue(distribution_strategy_context.in_cross_replica_context())
- self.assertTrue(distribution_strategy_context.has_strategy())
- self.assertIs(dist,
- distribution_strategy_context.get_strategy())
+ self.assertIs(None, ds_context.get_replica_context())
+ self.assertIs(dist, ds_context.get_cross_replica_context())
+ self.assertTrue(ds_context.in_cross_replica_context())
+ self.assertTrue(ds_context.has_strategy())
+ self.assertIs(dist, ds_context.get_strategy())
expected_value = _get_test_variable(
"baz", variable_scope.VariableSynchronization.AUTO,
variable_scope.VariableAggregation.NONE)
@@ -138,22 +133,16 @@
_assert_in_default_state(self)
def merge_fn(dist, s):
- self.assertIs(
- distribution_strategy_context._get_default_strategy(),
- dist)
- self.assertIs(None, distribution_strategy_context.get_replica_context())
- self.assertIs(dist,
- distribution_strategy_context.get_cross_replica_context())
- self.assertTrue(distribution_strategy_context.in_cross_replica_context())
- self.assertIs(dist,
- distribution_strategy_context.get_strategy())
- self.assertFalse(
- distribution_strategy_context.has_strategy())
+ self.assertIs(ds_context._get_default_strategy(), dist)
+ self.assertIs(None, ds_context.get_replica_context())
+ self.assertIs(dist, ds_context.get_cross_replica_context())
+ self.assertTrue(ds_context.in_cross_replica_context())
+ self.assertIs(dist, ds_context.get_strategy())
+ self.assertFalse(ds_context.has_strategy())
return "foo_" + s
- replica_ctx = distribution_strategy_context.get_replica_context()
- self.assertIs(distribution_strategy_context._get_default_replica_context(),
- replica_ctx)
+ replica_ctx = ds_context.get_replica_context()
+ self.assertIs(ds_context._get_default_replica_context(), replica_ctx)
self.assertEqual("foo_bar", replica_ctx.merge_call(merge_fn, args=("bar",)))
_assert_in_default_state(self)
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 7d5f231..0ec6703 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -24,6 +24,7 @@
from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
@@ -296,10 +297,11 @@
assert estimator._config._distribute_coordinator_mode
run_config = estimator._config
assert estimator._config.cluster_spec
- cluster_spec = estimator._config.cluster_spec
+ cluster_spec = multi_worker_util.normalize_cluster_spec(
+ estimator._config.cluster_spec)
assert estimator._config._train_distribute
- if 'evaluator' in cluster_spec:
+ if 'evaluator' in cluster_spec.jobs:
raise ValueError("'evaluator' job is not supported if you don't use "
'`train_and_evaluate`')
@@ -344,10 +346,11 @@
assert estimator._config._distribute_coordinator_mode
run_config = estimator._config
assert estimator._config.cluster_spec
- cluster_spec = estimator._config.cluster_spec
+ cluster_spec = multi_worker_util.normalize_cluster_spec(
+ estimator._config.cluster_spec)
assert estimator._config._eval_distribute
- if 'evaluator' in cluster_spec:
+ if 'evaluator' in cluster_spec.jobs:
raise ValueError("'evaluator' job is not supported if you don't use "
'`train_and_evaluate`')
diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index be6b713..6b13db3 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -345,7 +345,7 @@
# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
def _get_batched_dataset(d):
- """Get the underlying batch dataset from the dataset object."""
+ """Get the batched dataset from `d`."""
# pylint: disable=protected-access
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset
@@ -361,24 +361,17 @@
"The batch operations can be followed by a prefetch.")
-def _get_batched_dataset_attributes(dataset):
- """Get `batch_size`, `drop_remainder`, and `prefetch_buffer` of dataset."""
+def _get_batched_dataset_attributes(d):
+ """Get `batch_size`, `drop_remainder` of dataset."""
# pylint: disable=protected-access
- assert isinstance(dataset,
+ assert isinstance(d,
(dataset_ops.BatchDataset, batching._MapAndBatchDataset))
- if isinstance(dataset, dataset_ops.BatchDataset):
- batch_size = dataset._batch_size
- drop_remainder = dataset._drop_remainder
- elif isinstance(dataset, batching._MapAndBatchDataset):
- batch_size = dataset._batch_size_t
- drop_remainder = dataset._drop_remainder_t
-
- prefetch_buffer = None
- if isinstance(dataset, dataset_ops.PrefetchDataset):
- prefetch_buffer = dataset._buffer_size
- elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
- and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
- prefetch_buffer = dataset._dataset._buffer_size
+ if isinstance(d, dataset_ops.BatchDataset):
+ batch_size = d._batch_size
+ drop_remainder = d._drop_remainder
+ elif isinstance(d, batching._MapAndBatchDataset):
+ batch_size = d._batch_size_t
+ drop_remainder = d._drop_remainder_t
# pylint: enable=protected-access
if tensor_util.is_tensor(batch_size):
@@ -387,14 +380,35 @@
if tensor_util.is_tensor(drop_remainder):
drop_remainder = tensor_util.constant_value(drop_remainder)
+ return batch_size, drop_remainder
+
+
+# TODO(sourabhbajaj): Remove this in lieu of distributed datasets
+def _get_dataset_attributes(dataset):
+ """Get the underlying attributes from the dataset object."""
+ # pylint: disable=protected-access
+
+ # First, get batch_size and drop_remainder from the dataset. We need
+ # to walk back the dataset creation process and find the batched version in
+ # order to get the attributes.
+ batched_dataset = _get_batched_dataset(dataset)
+ batch_size, drop_remainder = _get_batched_dataset_attributes(batched_dataset)
+
+ # Second, prefetch buffer should be get from the original dataset.
+ prefetch_buffer = None
+ if isinstance(dataset, dataset_ops.PrefetchDataset):
+ prefetch_buffer = dataset._buffer_size
+ elif (isinstance(dataset, dataset_ops.DatasetV1Adapter)
+ and isinstance(dataset._dataset, dataset_ops.PrefetchDataset)):
+ prefetch_buffer = dataset._dataset._buffer_size
+
return batch_size, drop_remainder, prefetch_buffer
def _split_dataset_batch(dataset, split_batch_by):
"""Divide a batch-ed dataset's batches into smaller batches."""
- batched_dataset = _get_batched_dataset(dataset)
batch_size, drop_remainder, prefetch_buffer = (
- _get_batched_dataset_attributes(batched_dataset))
+ _get_dataset_attributes(dataset))
if batch_size % split_batch_by:
raise ValueError(
diff --git a/tensorflow/python/distribute/one_device_strategy.py b/tensorflow/python/distribute/one_device_strategy.py
index b2d2f03..ff9f616 100644
--- a/tensorflow/python/distribute/one_device_strategy.py
+++ b/tensorflow/python/distribute/one_device_strategy.py
@@ -50,7 +50,6 @@
def __init__(self, container_strategy, device):
super(OneDeviceExtended, self).__init__(container_strategy)
self._device = device
- self._default_device = device
self._input_device = device_util.canonicalize("/device:CPU:0")
worker_device_pairs = [(self._input_device, [self._device])]
device_map = values.SingleDeviceMap(device)
@@ -62,8 +61,12 @@
if colocate_with is None:
with ops.device(self._device):
return next_creator(*args, **kwargs)
- with ops.colocate_with(colocate_with):
- return next_creator(*args, **kwargs)
+ elif isinstance(colocate_with, numpy_dataset.SingleDevice):
+ with ops.device(colocate_with.device):
+ return next_creator(*args, **kwargs)
+ else:
+ with ops.colocate_with(colocate_with):
+ return next_creator(*args, **kwargs)
def _validate_colocate_with_variable(self, colocate_with_variable):
values.validate_colocate(colocate_with_variable, self)
@@ -83,7 +86,7 @@
def _experimental_make_numpy_dataset(self, numpy_input, session):
return numpy_dataset.one_host_numpy_dataset(
- numpy_input, self._input_device, session)
+ numpy_input, numpy_dataset.SingleDevice(self._input_device), session)
def _broadcast_to(self, tensor, destinations):
del destinations
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index 6f798fc..aba342b 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -6,6 +6,7 @@
"//tensorflow/tools/test:performance.bzl",
"tf_py_logged_benchmark",
)
+load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test")
cc_library(
name = "pywrap_tfe_lib",
@@ -523,6 +524,22 @@
],
)
+tf_xla_py_test(
+ name = "def_function_xla_test",
+ srcs = ["def_function_xla_test.py"],
+ tags = [
+ "no_pip",
+ "nomac",
+ ],
+ deps = [
+ ":def_function",
+ "//tensorflow/compiler/tests:xla_test",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ ],
+)
+
py_library(
name = "wrap_function",
srcs = ["wrap_function.py"],
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 6117d8a..694b05c 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -80,6 +80,8 @@
return tensor_shape.as_shape(value).as_proto()
elif attr_type == [pywrap_tensorflow.TF_ATTR_SHAPE]:
return [tensor_shape.as_shape(v).as_proto() for v in value]
+ elif isinstance(value, str):
+ return value.encode()
return value
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index fd9be06..2318414 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -44,6 +44,7 @@
# Note that we do not protect this with a lock and instead rely on python's GIL
# and the idempotent nature of writes to provide thread safety.
_device_parsing_cache = {}
+_starting_device_spec = pydev.DeviceSpec.from_string("")
_MAXINT32 = 2**31 - 1
@@ -135,26 +136,52 @@
def __init__(self, config=None):
super(_EagerContext, self).__init__()
- self.device_spec = pydev.DeviceSpec.from_string("")
- self.device_name = self.device_spec.to_string()
+ self.device_spec = _starting_device_spec
+ self.device_name = ""
self.mode = default_execution_mode
self.is_eager = default_execution_mode == EAGER_MODE
self.scope_name = ""
self.recording_summaries = False
self.summary_writer_resource = None
self.scalar_cache = {}
- self.ones_rank_cache = _EagerTensorCache()
- self.zeros_cache = _EagerTensorCache()
+ self._ones_rank_cache = None
+ self._zeros_cache = None
self.execution_mode = None
# Default rewriter config corresponds to turning all default grappler
# optimizations on.
- base_config = config_pb2.ConfigProto()
+ self._config = config
- if config is not None:
- base_config.MergeFrom(config)
+ self._function_call_options = None
- self.function_call_options = FunctionCallOptions(config_proto=base_config)
+ @property
+ def function_call_options(self):
+ if self._function_call_options is None:
+ base_config = config_pb2.ConfigProto()
+ if self._config is not None:
+ base_config.MergeFrom(self._config)
+ self._config = None
+ self._function_call_options = FunctionCallOptions(
+ config_proto=base_config)
+
+ return self._function_call_options
+
+ @function_call_options.setter
+ def function_call_options(self, function_call_options):
+ self._function_call_options = function_call_options
+ self._config = None
+
+ @property
+ def ones_rank_cache(self):
+ if not self._ones_rank_cache:
+ self._ones_rank_cache = _EagerTensorCache()
+ return self._ones_rank_cache
+
+ @property
+ def zeros_cache(self):
+ if not self._zeros_cache:
+ self._zeros_cache = _EagerTensorCache()
+ return self._zeros_cache
ContextSwitch = collections.namedtuple(
diff --git a/tensorflow/python/eager/def_function.py b/tensorflow/python/eager/def_function.py
index 897a38e..59c6608 100644
--- a/tensorflow/python/eager/def_function.py
+++ b/tensorflow/python/eager/def_function.py
@@ -25,7 +25,6 @@
from tensorflow.python.eager import context
from tensorflow.python.eager import function as function_lib
from tensorflow.python.eager import lift_to_graph
-from tensorflow.python.eager import tape
from tensorflow.python.framework import func_graph as func_graph_module
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
@@ -57,8 +56,6 @@
constraint=None,
add_initializers_to=None,
lifted_initializer_graph=None,
- lifted_all_initializers=None,
- lifted_placeholders=None,
**unused_kwargs):
"""Creates a variable.
@@ -90,13 +87,9 @@
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
add_initializers_to: if not None and not in legacy graph mode, the
- initializer tensor will be added to this map instead of adding the
+ initializer tensor will be added to this map in addition to adding the
assignment to the function.
lifted_initializer_graph: FuncGraph to try to lift initializers to.
- lifted_all_initializers: list with one boolean element, which will be
- set to False if we cannot lift this initializer to the above graph.
- lifted_placeholders: placeholders for resource handles lifted out of
- this graph.
Raises:
ValueError: If the initial value is not specified, or does not have a
@@ -174,7 +167,6 @@
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
self._initializer_op = resource_variable_ops.assign_variable_op(
self._handle, lifted_initializer, name=n)
- assign = self._initializer_op
with ops.name_scope("Read"), ops.colocate_with(self._handle):
# Manually assign reads to the handle's device to avoid log
# messages.
@@ -185,32 +177,21 @@
else:
if add_initializers_to is not None:
add_initializers_to[self] = initial_value
- assign = None
- else:
- def assign_fn():
- with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
- resource_variable_ops.assign_variable_op(
- self._handle,
- initial_value,
- name=n)
- # Returning values to keep tf.cond happy.
- return ops.convert_to_tensor(1)
- def not_assign_fn():
- return ops.convert_to_tensor(0)
- # Note: this cond is always guaranteed to run because we're inside a
- # defun which will insert automatic control dependencies.
- assign = control_flow_ops.cond(
- resource_variable_ops.var_is_initialized_op(self._handle),
- not_assign_fn, assign_fn)
- if lifted_initializer_graph is not None and assign is not None:
- try:
- handle_placeholder = ops.convert_to_tensor(self._handle)
- op_map = lift_to_graph.lift_to_graph(
- assign, lifted_initializer_graph,
- sources=[handle_placeholder])
- lifted_placeholders.append((self._handle, op_map[handle_placeholder]))
- except ValueError:
- lifted_all_initializers[0] = False
+ def assign_fn():
+ with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
+ resource_variable_ops.assign_variable_op(
+ self._handle,
+ initial_value,
+ name=n)
+ # Returning values to keep tf.cond happy.
+ return ops.convert_to_tensor(1)
+ def not_assign_fn():
+ return ops.convert_to_tensor(0)
+ # Note: this cond is always guaranteed to run because we're inside a
+ # defun which will insert automatic control dependencies.
+ control_flow_ops.cond(
+ resource_variable_ops.var_is_initialized_op(self._handle),
+ not_assign_fn, assign_fn)
# After the handle has been created, set up a way to clean it up when
# executing eagerly. We'll hold the only reference to the deleter, so that
@@ -340,16 +321,12 @@
created_variables = []
lifted_initializer_graph = func_graph_module.FuncGraph("initializer")
- lifted_all_initializers = [True]
- lifted_placeholders = []
def variable_capturing_scope(unused_next_creator, **kwds):
"""Creates UnliftedInitializerVariables and saves references to them."""
v = UnliftedInitializerVariable(
add_initializers_to=add_initializers_to,
- lifted_initializer_graph=lifted_initializer_graph,
- lifted_all_initializers=lifted_all_initializers,
- lifted_placeholders=lifted_placeholders, **kwds)
+ lifted_initializer_graph=lifted_initializer_graph, **kwds)
created_variables.append(weakref.ref(v))
return v
@@ -359,11 +336,9 @@
# Force the definition of the function for these arguments
self._lifted_initializer_graph = lifted_initializer_graph
self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph)
- self._lifted_placeholders = lifted_placeholders
self._concrete_stateful_fn = (
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
*args, **kwds))
- self._lifted_all_initializers = lifted_all_initializers[0]
def invalid_creator_scope(*unused_args, **unused_kwds):
"""Disables variable creation."""
@@ -390,21 +365,22 @@
return results
# This is the first call of __call__, so we have to initialize.
- self._initialize(args, kwds)
- if self._lifted_all_initializers and self._lifted_placeholders:
- with ops.init_scope():
- handles, placeholders = zip(*self._lifted_placeholders)
- if context.executing_eagerly():
- lifted_fn = function_lib._EagerDefinedFunction( # pylint: disable=protected-access
- "initializer" + str(ops.uid()),
- self._lifted_initializer_graph,
- placeholders, [], {})
- with tape.stop_recording():
- lifted_fn.call(context.context(), list(handles))
- return self._stateless_fn(*args, **kwds)
- canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
-
- if not self._created_variables:
+ initializer_map = {}
+ self._initialize(args, kwds, add_initializers_to=initializer_map)
+ if self._created_variables:
+ try:
+ # Attempt to initialize variables eagerly and without conds by lifting
+ # out initialization graphs. This is the only initialization strategy
+ # compatible with XLA at the moment.
+ self._initialize_uninitialized_variables(initializer_map)
+ except lift_to_graph.UnliftableError:
+ pass # Fall through to cond-based initialization.
+ else:
+ # Lifting succeeded, so variables are initialized and we can run the
+ # stateless function.
+ return self._stateless_fn(*args, **kwds)
+ else:
+ canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
# If we did not create any variables the trace we have is good enough.
return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access
@@ -459,6 +435,9 @@
functools.partial(self._concrete_stateful_fn._filtered_call, # pylint: disable=protected-access
inner_args, inner_kwds))
+ # We've created variables and are unable to lift the initialization graphs,
+ # so we fall back to initializing with conds while running the function.
+ canon_args, canon_kwds = self._canonicalize_function_inputs(args, kwds)
return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds)
@property
@@ -474,6 +453,23 @@
def function_spec(self):
return self._function_spec
+ def _initialize_uninitialized_variables(self, initializer_map):
+ """Make and call a `ConcreteFunction` which initializes variables."""
+
+ # Note: using defun here avoids an infinite recursion.
+ @function_lib.defun
+ def initialize_variables():
+ for v, init in initializer_map.items():
+ with ops.init_scope():
+ if resource_variable_ops.var_is_initialized_op(v.handle):
+ # Ignore variables which are already initialized at trace time.
+ continue
+ v.assign(lift_to_graph.lift_to_graph(
+ init, ops.get_default_graph())[init])
+
+ with ops.init_scope():
+ return initialize_variables.get_concrete_function()()
+
def get_initialization_function(self, *args, **kwargs):
"""Returns a `ConcreteFunction` which initializes this function's variables.
@@ -482,6 +478,9 @@
function which does not depend on the concrete values of the inputs to this
function.
+ Note that running this function will overwrite any values currently assigned
+ to variables, for example restores from a checkpoint.
+
Args:
*args: arguments to the underlying python callable.
**kwargs: keyword arguments to the python callable.
@@ -624,9 +623,10 @@
Raises:
ValueError: if this object has not yet been called on concrete values.
"""
- assert context.executing_eagerly()
if self._stateful_fn is None:
- self.get_initialization_function(*args, **kwargs)()
+ initializer_map = {}
+ self._initialize(args, kwargs, add_initializers_to=initializer_map)
+ self._initialize_uninitialized_variables(initializer_map)
if self._created_variables:
# In this case we have created variables on the first call, so we run the
diff --git a/tensorflow/python/eager/def_function_test.py b/tensorflow/python/eager/def_function_test.py
index 912198d..b49b165 100644
--- a/tensorflow/python/eager/def_function_test.py
+++ b/tensorflow/python/eager/def_function_test.py
@@ -21,7 +21,9 @@
import weakref
from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
+from tensorflow.python.eager import lift_to_graph
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -30,6 +32,7 @@
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.layers import core
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -208,7 +211,7 @@
state.append(variables.Variable(2.0 * x))
return state[0] * x
- with self.assertRaises(ValueError):
+ with self.assertRaises(lift_to_graph.UnliftableError):
fn(constant_op.constant(3.0))
def testMethod(self):
@@ -265,7 +268,8 @@
self.assertAllClose(4., concrete(constant_op.constant(2.)))
signature_args, _ = concrete.structured_input_signature
self.assertEqual(signature_args,
- (tensor_spec.TensorSpec(None, dtypes.float32),))
+ (tensor_spec.TensorSpec(
+ None, dtypes.float32, name='x'),))
def test_serialization_signature_cache(self):
@@ -285,10 +289,10 @@
self.assertEqual(
signatures_args,
- set(((tensor_spec.TensorSpec([1, 2], dtypes.float32),
- tensor_spec.TensorSpec([1], dtypes.float32)),
- (tensor_spec.TensorSpec([1, 3], dtypes.int32),
- tensor_spec.TensorSpec([1], dtypes.int32)))))
+ set(((tensor_spec.TensorSpec([1, 2], dtypes.float32, name='x'),
+ tensor_spec.TensorSpec([1], dtypes.float32, name='y')),
+ (tensor_spec.TensorSpec([1, 3], dtypes.int32, name='x'),
+ tensor_spec.TensorSpec([1], dtypes.int32, name='y')))))
@test_util.assert_no_garbage_created
def testFunctionReferenceCycles(self):
@@ -343,6 +347,88 @@
f()
self.assertEqual(created_variables, captured_variables)
+ def testVarAlreadyInitializedNoClobbering(self):
+ v_holder = []
+
+ @def_function.function
+ def add_var(x):
+ if not v_holder:
+ v = variables.Variable([1., 2.])
+ v_holder.append(v)
+ already_initialized = variables.Variable(3.)
+ with ops.init_scope():
+ already_initialized.assign(10.)
+ v_holder.append(already_initialized)
+ return v_holder[0] + v_holder[1] + x
+
+ add_var.get_concrete_function(constant_op.constant(2.))
+ self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
+
+ def testSameVariableTwice(self):
+
+ v = variables.Variable(1.0)
+
+ @def_function.function
+ def add(a, b):
+ return a + b
+
+ self.assertAllEqual(add(v, v), 2.0)
+
+ def testShapeCache(self):
+ @def_function.function
+ def func(x):
+ return 2 * x
+
+ func_a = func.get_concrete_function(
+ tensor_spec.TensorSpec([None], dtypes.int32))
+ func_b = func.get_concrete_function(
+ tensor_spec.TensorSpec([None], dtypes.int32))
+
+ self.assertIs(func_a, func_b)
+
+ def testInitializationInNestedCall(self):
+ v_holder = []
+
+ @def_function.function
+ def add_var(x):
+ if not v_holder:
+ v = variables.Variable([1., 2.])
+ v_holder.append(v)
+ already_initialized = variables.Variable(3.)
+ with ops.init_scope():
+ already_initialized.assign(10.)
+ v_holder.append(already_initialized)
+ return v_holder[0] + v_holder[1] + x
+
+ @def_function.function
+ def wrapper(x):
+ return add_var(x)
+
+ self.assertAllClose([13., 14.], wrapper(constant_op.constant(2.)))
+ v_holder[1].assign(11.)
+ self.assertAllClose([14., 15.], wrapper(constant_op.constant(2.)))
+
+ def testDeviceAnnotationRespected(self):
+ if not context.num_gpus():
+ self.skipTest("Needs multiple devices")
+
+ a = []
+
+ @def_function.function()
+ def create_variable():
+ with ops.init_scope():
+ initial_value = random_ops.random_uniform(
+ (2, 2), maxval=1000000, dtype=dtypes.int64)
+
+ if not a:
+ with ops.device("CPU:0"):
+ a.append(resource_variable_ops.ResourceVariable(initial_value))
+
+ return a[0].read_value()
+
+ created_variable_read = create_variable()
+ self.assertRegexpMatches(created_variable_read.device, "CPU")
+
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/eager/def_function_xla_test.py b/tensorflow/python/eager/def_function_xla_test.py
new file mode 100644
index 0000000..9115d8a
--- /dev/null
+++ b/tensorflow/python/eager/def_function_xla_test.py
@@ -0,0 +1,49 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests import xla_test
+from tensorflow.python.eager import def_function
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class DefFunctionTests(xla_test.XLATestCase):
+
+ def testVarInitializedInFunction(self):
+ with self.test_scope():
+ v_holder = []
+
+ @def_function.function
+ def add_var(x):
+ if not v_holder:
+ v = variables.Variable([1., 2.])
+ v_holder.append(v)
+ already_initialized = variables.Variable(3.)
+ with ops.init_scope():
+ already_initialized.assign(10.)
+ v_holder.append(already_initialized)
+ return v_holder[0] + v_holder[1] + x
+
+ self.assertAllClose([13., 14.], add_var(constant_op.constant(2.)))
+
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 00932c4..04f1999 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -22,7 +22,6 @@
import collections
import functools
import re
-import sys
import threading
import types as types_lib
import weakref
@@ -48,7 +47,7 @@
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
@@ -58,8 +57,6 @@
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
-# This is to avoid a circular dependency with gradients_impl
-gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
@@ -478,11 +475,17 @@
tape.variables_accessed(self._func_graph.variables)
tensor_inputs = []
+ variables_used = set([])
for i, arg in enumerate(args):
if isinstance(arg, resource_variable_ops.ResourceVariable):
+ # We can pass a variable more than once, and in this case we need to
+ # pass its handle only once.
+ if arg.handle in variables_used:
+ continue
if arg.trainable:
tape.variable_accessed(arg)
tensor_inputs.append(arg.handle)
+ variables_used.add(arg.handle)
elif isinstance(arg, ops.Tensor):
tensor_inputs.append(arg)
elif (self._signature is not None and
@@ -664,12 +667,12 @@
_backward_name(self._func_graph.name))
forward_function_name = _forward_name(self._func_graph.name)
outputs = [x for x in self._func_graph.outputs
- if gradients_impl.IsTrainable(x)]
+ if gradients_util.IsTrainable(x)]
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in outputs
]
- gradients_wrt_inputs = gradients_impl._GradientsHelper( # pylint: disable=protected-access
+ gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access
outputs,
self._func_graph.inputs,
grad_ys=gradients_wrt_outputs,
@@ -738,7 +741,7 @@
# the forward graph function so that we can compute its gradient.
real_outputs = outputs[:self._num_outputs]
skip_positions = [i for i, t in enumerate(real_outputs)
- if not gradients_impl.IsTrainable(t)]
+ if not gradients_util.IsTrainable(t)]
side_outputs = outputs[self._num_outputs:]
def backward_function(*args):
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index c8f7ecb..871e4e2 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -466,6 +466,22 @@
value = tensor_init()
self.assertAllEqual(value, 2.0)
+ @test_util.run_in_graph_and_eager_modes
+ def testGetConcreteFunctionCreatesVariables(self):
+
+ v_holder = []
+
+ @def_function.function
+ def tensor_init():
+ if not v_holder:
+ v_holder.append(variables.Variable(5.))
+ return v_holder[0].read_value()
+
+ concrete = tensor_init.get_concrete_function()
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllEqual(5., self.evaluate(concrete()))
+ self.assertAllEqual(5., self.evaluate(tensor_init()))
+
def testDefunShapeInferenceWithCapturedResourceVariable(self):
v = resource_variable_ops.ResourceVariable([[1, 2], [3, 4]])
@@ -807,8 +823,9 @@
return None
with self.assertRaisesRegexp(
- errors.InvalidArgumentError, 'Could not colocate node with its '
- 'resource and reference inputs.*'):
+ errors.InvalidArgumentError,
+ 'Cannot place the graph because a reference or resource edge connects '
+ 'colocation groups with incompatible assigned devices'):
if not context.executing_eagerly():
self.evaluate(variables.global_variables_initializer())
self.evaluate(resource_apply_adam())
@@ -1057,7 +1074,7 @@
def func():
return constant_op.constant(0)
- defined = function.defun(func)
+ defined = def_function.function(func)
with ops.device('cpu:0'):
cpu_graph_function = defined.get_concrete_function()
@@ -1324,6 +1341,7 @@
'tuple or a list.*'):
function.defun(foo, input_signature=signature)
+ @test_util.run_in_graph_and_eager_modes
def testInputsIncompatibleWithSignatureRaisesError(self):
def foo(a):
diff --git a/tensorflow/python/eager/lift_to_graph.py b/tensorflow/python/eager/lift_to_graph.py
index 2e9d24f..ad62e6d 100644
--- a/tensorflow/python/eager/lift_to_graph.py
+++ b/tensorflow/python/eager/lift_to_graph.py
@@ -35,6 +35,11 @@
return op_or_tensor
+class UnliftableError(Exception):
+ """Raised if a Tensor cannot be lifted from the graph."""
+ pass
+
+
def lift_to_graph(init_tensor, graph, sources=None):
"""Copies the tensor and all its inputs recursively to the outer graph."""
# Check that the initializer does not depend on any placeholders.
@@ -52,7 +57,7 @@
# and placeholders the user might directly use to initialize
# variables.
if op.type == "Placeholder":
- raise ValueError(
+ raise UnliftableError(
"Unable to lift tensor", init_tensor,
"because it depends transitively on placeholder ", op)
for inp in _graph_inputs(op):
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index 8d6f212..63440c0 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -234,4 +234,6 @@
// for the defun function cache.
PyObject* TFE_Py_EncodeArg(PyObject*);
+void TFE_Py_EnableInteractivePythonLogging();
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index da1bb24..eb2f28d 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -3033,3 +3033,36 @@
return result.ToPyTuple();
}
+
+// A method prints incoming messages directly to Python's
+// stdout using Python's C API. This is necessary in Jupyter notebooks
+// and colabs where messages to the C stdout don't go to the notebook
+// cell outputs, but calls to Python's stdout do.
+void PrintToPythonStdout(const char* msg) {
+ if (Py_IsInitialized()) {
+ PyGILState_STATE py_threadstate;
+ py_threadstate = PyGILState_Ensure();
+
+ string string_msg = msg;
+ // PySys_WriteStdout truncates strings over 1000 bytes, so
+ // we write the message in chunks small enough to not be truncated.
+ int CHUNK_SIZE = 900;
+ auto len = string_msg.length();
+ for (int i = 0; i < len; i += CHUNK_SIZE) {
+ PySys_WriteStdout("%s", string_msg.substr(i, CHUNK_SIZE).c_str());
+ }
+ PySys_WriteStdout("\n");
+
+ PyGILState_Release(py_threadstate);
+ }
+}
+
+// Register PrintToPythonStdout as a log listener, to allow
+// printing in colabs and jupyter notebooks to work.
+void TFE_Py_EnableInteractivePythonLogging() {
+ static bool enabled_interactive_logging = false;
+ if (!enabled_interactive_logging) {
+ enabled_interactive_logging = true;
+ TF_RegisterLogListener(PrintToPythonStdout);
+ }
+}
diff --git a/tensorflow/python/eager/tensor_test.py b/tensorflow/python/eager/tensor_test.py
index 0ee2ff6..0d8845b 100644
--- a/tensorflow/python/eager/tensor_test.py
+++ b/tensorflow/python/eager/tensor_test.py
@@ -339,6 +339,24 @@
def testConvertToTensorAllowsOverflow(self):
_ = ops.convert_to_tensor(123456789, dtype=dtypes.uint8)
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNumpyZeroDim(self):
+ for np_type, dtype in [(np.int32, dtypes.int32),
+ (np.half, dtypes.half),
+ (np.float32, dtypes.float32)]:
+ x = ops.convert_to_tensor([np.array(65, dtype=np_type),
+ np.array(16, dtype=np_type)])
+ self.assertEqual(x.dtype, dtype)
+ self.assertAllEqual(x, [65, 16])
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ @test_util.run_in_graph_and_eager_modes
+ def testConvertToTensorNumpyScalar(self):
+ x = ops.convert_to_tensor([np.asscalar(np.array(321, dtype=np.int)),
+ np.asscalar(np.array(16, dtype=np.int))])
+ self.assertAllEqual(x, [321, 16])
+
def testEagerTensorError(self):
with self.assertRaisesRegexp(
TypeError,
@@ -347,7 +365,6 @@
_ = ops.convert_to_tensor(1., dtype=dtypes.int32)
-
class TFETensorUtilTest(test_util.TensorFlowTestCase):
def testListOfThree(self):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index a942456..5ebbaab 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -169,7 +169,7 @@
from tensorflow.python.util.tf_export import tf_export
-_FEATURE_COLUMN_DEPRECATION_DATE = '2018-11-30'
+_FEATURE_COLUMN_DEPRECATION_DATE = None
_FEATURE_COLUMN_DEPRECATION = ('The old _FeatureColumn APIs are being '
'deprecated. Please use the new FeatureColumn '
'APIs instead.')
@@ -380,7 +380,7 @@
return array_ops.concat(output_tensors, -1)
-@keras_export('keras.layers.DenseFeatures', v1=[])
+@keras_export('keras.layers.DenseFeatures')
class DenseFeatures(_BaseFeaturesLayer):
"""A layer that produces a dense `Tensor` based on given `feature_columns`.
@@ -4022,13 +4022,9 @@
def transform_feature(self, transformation_cache, state_manager):
"""Applies weights to tensor generated from `categorical_column`'."""
- print('WeightedCategoricalColumn.transform_feature: ', self.name)
- print('Weight feature key: ', self.weight_feature_key)
weight_tensor = transformation_cache.get(self.weight_feature_key,
state_manager)
- print('Weight tensor before: ', weight_tensor)
weight_tensor = self._transform_weight_tensor(weight_tensor)
- print('Weight tensor after: ', weight_tensor)
return (transformation_cache.get(self.categorical_column, state_manager),
weight_tensor)
@@ -4042,9 +4038,7 @@
def get_sparse_tensors(self, transformation_cache, state_manager):
"""See `CategoricalColumn` base class."""
- print('WeightedCategoricalColumn.get_sparse_tensors: ', self.name)
tensors = transformation_cache.get(self, state_manager)
- print('tensors[1]: ', tensors[1])
return CategoricalColumn.IdWeightPair(tensors[0], tensors[1])
@deprecation.deprecated(_FEATURE_COLUMN_DEPRECATION_DATE,
diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py
index da76a84..6210010 100644
--- a/tensorflow/python/framework/auto_control_deps.py
+++ b/tensorflow/python/framework/auto_control_deps.py
@@ -249,23 +249,28 @@
ops_which_must_run = set([op])
continue
found_resource = False
- for inp in op.inputs:
- if inp.dtype == dtypes_module.resource:
- found_resource = True
- # Deal with switches, finally.
- if inp.op.type == "Switch":
- self._process_switch(inp.op, ops_which_must_run,
- last_op_using_resource_tensor,
- merge_for_resource)
- # Ensure uses of resources are serialized
- if inp in last_op_using_resource_tensor:
- if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
- is op._control_flow_context): # pylint: disable=protected-access
- control_inputs.add(last_op_using_resource_tensor[inp])
- # Ensure merges happen after the closing of a cond block
- if inp in merge_for_resource:
- merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
- last_op_using_resource_tensor[inp] = op
+ # Check for any resource inputs. If we find any, we update control_inputs
+ # and last_op_using_resource_tensor. Note that we dedup op.inputs in case
+ # op receives the same resource tensor twice as input, which would result
+ # in op getting a control dependency on itself.
+ for inp in set(op.inputs):
+ if inp.dtype != dtypes_module.resource:
+ continue
+ found_resource = True
+ # Deal with switches, finally.
+ if inp.op.type == "Switch":
+ self._process_switch(inp.op, ops_which_must_run,
+ last_op_using_resource_tensor,
+ merge_for_resource)
+ # Ensure uses of resources are serialized
+ if inp in last_op_using_resource_tensor:
+ if (last_op_using_resource_tensor[inp]._control_flow_context # pylint: disable=protected-access
+ is op._control_flow_context): # pylint: disable=protected-access
+ control_inputs.add(last_op_using_resource_tensor[inp])
+ # Ensure merges happen after the closing of a cond block
+ if inp in merge_for_resource:
+ merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
+ last_op_using_resource_tensor[inp] = op
if (op.op_def.is_stateful and op.type not in ASYNC_STATEFUL_OPS
and not found_resource and op._control_flow_context is None): # pylint: disable=protected-access
if None in last_op_using_resource_tensor:
diff --git a/tensorflow/python/framework/auto_control_deps_test.py b/tensorflow/python/framework/auto_control_deps_test.py
index d81adef..2c25ab1 100644
--- a/tensorflow/python/framework/auto_control_deps_test.py
+++ b/tensorflow/python/framework/auto_control_deps_test.py
@@ -19,6 +19,7 @@
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
from tensorflow.python.eager import function
from tensorflow.python.framework import auto_control_deps as acd
from tensorflow.python.framework import constant_op
@@ -281,6 +282,20 @@
train()
self.assertEqual(v.numpy(), -1.0)
+ def testRepeatedResourceInput(self):
+ var = resource_variable_ops.ResourceVariable(1.0)
+
+ @def_function.function
+ def inner(var1, var2):
+ return (resource_variable_ops.read_variable_op(var1, dtypes.float32) +
+ resource_variable_ops.read_variable_op(var2, dtypes.float32))
+
+ @def_function.function
+ def outer():
+ return inner(var.handle, var.handle)
+
+ self.assertEqual(self.evaluate(outer()), 2.0)
+
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index 9d643e0..6638be2 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -282,9 +282,6 @@
"""Returns the string name for this `DType`."""
return _TYPE_TO_STRING[self._type_enum]
- def __int__(self):
- return self._type_enum
-
def __str__(self):
return "<dtype: %r>" % self.name
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 922b9e2..c473dfe 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -511,7 +511,11 @@
@tf_export("errors.error_code_from_exception_type")
def error_code_from_exception_type(cls):
- return _EXCEPTION_CLASS_TO_CODE[cls]
+ try:
+ return _EXCEPTION_CLASS_TO_CODE[cls]
+ except KeyError:
+ warnings.warn("Unknown class exception")
+ return UnknownError(None, None, "Unknown class exception", None)
def _make_specific_exception(node_def, op, message, error_code):
diff --git a/tensorflow/python/framework/errors_test.py b/tensorflow/python/framework/errors_test.py
index 574b126..c044202 100644
--- a/tensorflow/python/framework/errors_test.py
+++ b/tensorflow/python/framework/errors_test.py
@@ -70,6 +70,10 @@
isinstance(
errors_impl._make_specific_exception(None, None, None,
error_code), exc_type))
+ # error_code_from_exception_type and exception_type_from_error_code should
+ # be consistent with operation result.
+ self.assertEqual(error_code,
+ errors_impl.error_code_from_exception_type(exc_type))
# pylint: enable=protected-access
def testKnownErrorClassForEachErrorCodeInProto(self):
@@ -98,6 +102,14 @@
self.assertTrue("Unknown error code: 37" in str(w[0].message))
self.assertTrue(isinstance(exc, errors_impl.OpError))
+ with warnings.catch_warnings(record=True) as w:
+ # pylint: disable=protected-access
+ exc = errors_impl.error_code_from_exception_type("Unknown")
+ # pylint: enable=protected-access
+ self.assertEqual(1, len(w))
+ self.assertTrue("Unknown class exception" in str(w[0].message))
+ self.assertTrue(isinstance(exc, errors_impl.OpError))
+
def testStatusDoesNotLeak(self):
try:
with errors.raise_exception_on_not_ok_status() as status:
diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py
index fe96848..3f81466 100644
--- a/tensorflow/python/framework/func_graph.py
+++ b/tensorflow/python/framework/func_graph.py
@@ -556,7 +556,9 @@
# Even if an argument variable was not used in the function, we've
# already manually captured the resource Tensor when creating argument
# placeholders.
- resource_placeholder = func_graph.captures.pop(arg.handle)
+ resource_placeholder = func_graph.captures.pop(arg.handle, None)
+ if resource_placeholder is None:
+ continue
arg_variables.add(arg)
inputs.append(resource_placeholder)
elif isinstance(arg, ops.Tensor):
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index 3d5a5fe..7543376 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -284,7 +284,6 @@
out, = sess.run(dlogits, {logits: x, labels: y})
self.assertAllClose(out, np.exp(prob - y))
- @test_util.disable_xla("This test never passed for XLA")
def testCustomGradientError(self):
dtype = dtypes.float32
diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py
index 1b61ac9..a46fccc 100644
--- a/tensorflow/python/framework/graph_util_impl.py
+++ b/tensorflow/python/framework/graph_util_impl.py
@@ -143,13 +143,14 @@
# Breadth first search to find all the nodes that we should keep.
next_to_visit = target_nodes[:]
while next_to_visit:
- n = next_to_visit[0]
+ node = next_to_visit[0]
del next_to_visit[0]
- if n in nodes_to_keep:
+ if node in nodes_to_keep:
# Already visited this node.
continue
- nodes_to_keep.add(n)
- next_to_visit += name_to_input_name[n]
+ nodes_to_keep.add(node)
+ if node in name_to_input_name:
+ next_to_visit += name_to_input_name[node]
return nodes_to_keep
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index e74b43f..6e64622 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -95,9 +95,12 @@
lineno = -1
self.display_name = "%s<%s, %d>" % (func_name, fname, lineno)
+ self.raw_string = None
+
self.function = self._device_name_or_function
if not (self._device_name_or_function is None or
callable(self._device_name_or_function)):
+ self.raw_string = self._device_name_or_function
self.function = pydev.merge_device(self._device_name_or_function)
@@ -3048,9 +3051,6 @@
# being called inside function definitions behave as if they were seeing the
# actual outside graph).
self._graph_key = "grap-key-%d/" % (uid(),)
- # A string with the last reduction method passed to
- # losses.compute_weighted_loss(), or None.
- self._last_loss_reduction = None
self._container = ""
self._registered_ops = op_def_registry.get_registered_ops()
# Set to True if this graph is being built in an
@@ -5547,6 +5547,8 @@
try:
with outer_context(), name_scope(scope), control_dependencies(
None), tape.stop_recording():
+ context_manager = NullContextmanager
+ context_manager_input = None
if not context.executing_eagerly():
# The device stack is preserved when lifting into a graph. Eager
# execution doesn't implement device stacks and in particular it
@@ -5555,7 +5557,21 @@
outer_graph = get_default_graph()
outer_device_stack = outer_graph._device_function_stack # pylint: disable=protected-access
outer_graph._device_function_stack = innermost_nonempty_device_stack # pylint: disable=protected-access
- yield
+ elif innermost_nonempty_device_stack is not None:
+ for device_spec in innermost_nonempty_device_stack.peek_objs():
+ if device_spec.function is None:
+ break
+ if device_spec.raw_string:
+ context_manager = context.device
+ context_manager_input = device_spec.raw_string
+ break
+ # It is currently not possible to have a device function in V2,
+ # but in V1 we are unable to apply device functions in eager mode.
+ # This means that we will silently skip some of the entries on the
+ # device stack in V1 + eager mode.
+
+ with context_manager(context_manager_input):
+ yield
finally:
# If an exception is raised here it may be hiding a related exception in
# try-block (just above).
diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc
index 0f30438..b0b9ce7 100644
--- a/tensorflow/python/framework/python_op_gen.cc
+++ b/tensorflow/python/framework/python_op_gen.cc
@@ -651,7 +651,7 @@
strings::StrAppend(&result_, " \"\"\"\n");
strings::StrAppend(&result_,
- " _ctx = _context._context\n"
+ " _ctx = _context._context or _context.context()\n"
" if _ctx is not None and _ctx._eager_context.is_eager:",
"\n");
if (eager_not_allowed_error.empty()) {
@@ -930,42 +930,45 @@
string function_call_parameters;
string inputs;
string attrs;
+
std::map<string, string> renames;
- for (const auto& input_arg : api_def_.in_arg()) {
- renames.insert({input_arg.name(), input_arg.rename_to()});
- }
- for (const auto& attr : api_def_.attr()) {
- renames.insert({attr.name(), attr.rename_to()});
+ for (const auto& param_names : param_names_) {
+ renames.insert({param_names.GetName(), param_names.GetRenameTo()});
}
for (const auto& input_arg : op_def_.input_arg()) {
+ const string input_arg_name =
+ python_op_gen_internal::AvoidPythonReserved(input_arg.name());
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
- strings::StrAppend(&raw_parameters, input_arg.name());
+ strings::StrAppend(&raw_parameters, input_arg_name);
if (!inputs.empty()) strings::StrAppend(&inputs, ", ");
- strings::StrAppend(&inputs, input_arg.name());
+ strings::StrAppend(&inputs, input_arg_name);
if (!function_call_parameters.empty()) {
strings::StrAppend(&function_call_parameters, ", ");
}
strings::StrAppend(&function_call_parameters, renames[input_arg.name()],
- "=", input_arg.name());
+ "=", input_arg_name);
}
for (const auto& attr : op_def_.attr()) {
if (inferred_attrs_.find(attr.name()) != inferred_attrs_.end()) continue;
+ const string attr_name =
+ python_op_gen_internal::AvoidPythonReserved(attr.name());
+
if (!raw_parameters.empty()) strings::StrAppend(&raw_parameters, ", ");
- strings::StrAppend(&raw_parameters, attr.name());
+ strings::StrAppend(&raw_parameters, attr_name);
if (!attrs.empty()) strings::StrAppend(&attrs, ", ");
- strings::StrAppend(&attrs, "\"", attr.name(), "\", ", attr.name());
+ strings::StrAppend(&attrs, "\"", attr_name, "\", ", attr_name);
if (!function_call_parameters.empty()) {
strings::StrAppend(&function_call_parameters, ", ");
}
strings::StrAppend(&function_call_parameters, renames[attr.name()], "=",
- attr.name());
+ attr_name);
}
const string raw_function_name =
diff --git a/tensorflow/python/framework/registry_test.py b/tensorflow/python/framework/registry_test.py
index 1a0d3f2..5adf12f 100644
--- a/tensorflow/python/framework/registry_test.py
+++ b/tensorflow/python/framework/registry_test.py
@@ -19,28 +19,33 @@
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
from tensorflow.python.framework import registry
from tensorflow.python.platform import test
-class RegistryTest(test.TestCase):
+def bar():
+ pass
+
+
+class RegistryTest(test.TestCase, parameterized.TestCase):
class Foo(object):
pass
- def testRegisterClass(self):
- myreg = registry.Registry('testfoo')
+ # Test the registry basics on both classes (Foo) and functions (bar).
+ @parameterized.parameters([Foo, bar])
+ def testRegistryBasics(self, candidate):
+ myreg = registry.Registry('testRegistry')
with self.assertRaises(LookupError):
- myreg.lookup('Foo')
- myreg.register(RegistryTest.Foo, 'Foo')
- assert myreg.lookup('Foo') == RegistryTest.Foo
-
- def testRegisterFunction(self):
- myreg = registry.Registry('testbar')
- with self.assertRaises(LookupError):
- myreg.lookup('Bar')
- myreg.register(bar, 'Bar')
- assert myreg.lookup('Bar') == bar
+ myreg.lookup('testKey')
+ myreg.register(candidate)
+ self.assertEqual(myreg.lookup(candidate.__name__), candidate)
+ myreg.register(candidate, 'testKey')
+ self.assertEqual(myreg.lookup('testKey'), candidate)
+ self.assertEqual(
+ sorted(myreg.list()), sorted(['testKey', candidate.__name__]))
def testDuplicate(self):
myreg = registry.Registry('testbar')
@@ -51,9 +56,5 @@
myreg.register(bar, 'Bar')
-def bar():
- pass
-
-
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/framework/tensor_spec.py b/tensorflow/python/framework/tensor_spec.py
index c44636e..2e847c7 100644
--- a/tensorflow/python/framework/tensor_spec.py
+++ b/tensorflow/python/framework/tensor_spec.py
@@ -108,7 +108,9 @@
return hash((self._shape_tuple, self.dtype))
def __eq__(self, other):
- return self.shape == other.shape and self.dtype == other.dtype
+ return (self._shape_tuple == other._shape_tuple # pylint: disable=protected-access
+ and self.dtype == other.dtype
+ and self._name == other._name) # pylint: disable=protected-access
def __ne__(self, other):
return not self == other
diff --git a/tensorflow/python/framework/tensor_util.py b/tensorflow/python/framework/tensor_util.py
index 7de6653..21ded1a 100644
--- a/tensorflow/python/framework/tensor_util.py
+++ b/tensorflow/python/framework/tensor_util.py
@@ -22,6 +22,7 @@
from tensorflow.core.framework import tensor_pb2
from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat
@@ -932,13 +933,15 @@
return ret
+@tf_export("is_tensor")
def is_tensor(x): # pylint: disable=invalid-name
"""Check whether `x` is of tensor type.
- Check whether an object is a tensor. This check is equivalent to calling
- `isinstance(x, (tf.Tensor, tf.SparseTensor, tf.Variable))` and also checks
- if all the component variables of a MirroredVariable or a ReplicaLocalVariable
- are tensors.
+ Check whether an object is a tensor or a composite tensor. This check is
+ equivalent to calling
+ `isinstance(x, (tf.Tensor, tf.SparseTensor, tf.RaggedTensor, tf.Variable))`
+ and also checks if all the component variables of a MirroredVariable or a
+ ReplicaLocalVariable are tensors.
Args:
x: A python object to check.
@@ -947,4 +950,5 @@
`True` if `x` is a tensor, `False` if not.
"""
return (isinstance(x, ops._TensorLike) or ops.is_dense_tensor_like(x) or # pylint: disable=protected-access
+ isinstance(x, composite_tensor.CompositeTensor) or
(hasattr(x, "is_tensor_like") and x.is_tensor_like))
diff --git a/tensorflow/python/framework/tensor_util_test.py b/tensorflow/python/framework/tensor_util_test.py
index cdacdfa..e73df39 100644
--- a/tensorflow/python/framework/tensor_util_test.py
+++ b/tensorflow/python/framework/tensor_util_test.py
@@ -773,6 +773,16 @@
self.assertAllClose(np.array([10, 20, 30], dtype=np.int64), a)
+class IsTensorTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testConstantTensor(self):
+ np_val = np.random.rand(3).astype(np.int32)
+ tf_val = constant_op.constant(np_val)
+ self.assertFalse(tensor_util.is_tensor(np_val))
+ self.assertTrue(tensor_util.is_tensor(tf_val))
+
+
class ConstantValueTest(test.TestCase):
def testConstant(self):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index d27572f..1d26783 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1460,7 +1460,7 @@
value = getattr(cls, name)
if callable(value) and name.startswith(
"test") and not name == "test_session":
- setattr(cls, name, base_decorator(value))
+ setattr(cls, name, base_decorator(description)(value))
return cls
return disable_all_impl
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index ad895d2..d904193 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -91,6 +91,7 @@
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:ctc_ops",
+ "//tensorflow/python:distribute",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework",
"//tensorflow/python:framework_ops",
@@ -114,8 +115,10 @@
"//tensorflow/python:tensor_array_grad",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
+ "//tensorflow/python/distribute:distribute_coordinator",
],
)
@@ -161,6 +164,7 @@
":regularizers",
":saving",
"//tensorflow/python/data",
+ "//tensorflow/python/distribute:distribute_coordinator",
"//tensorflow/python/distribute:distribute_lib",
"//tensorflow/python/distribute:input_lib",
"//tensorflow/python/distribute:reduce_util",
@@ -288,6 +292,7 @@
srcs_version = "PY2AND3",
deps = [
":backend",
+ "//tensorflow/python/distribute:distribute_lib",
],
)
@@ -360,9 +365,13 @@
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:nn_ops",
],
shard_count = 12,
- tags = ["notsan"],
+ tags = [
+ "no_oss", # b/123899138
+ "notsan",
+ ],
)
tf_py_test(
@@ -374,6 +383,7 @@
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:nn_ops",
],
)
@@ -387,6 +397,9 @@
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
],
+ tags = [
+ "no_oss", # b/123899138
+ ],
)
tf_py_test(
@@ -404,7 +417,7 @@
tf_py_test(
name = "regularizers_test",
- size = "small",
+ size = "medium",
srcs = ["regularizers_test.py"],
additional_deps = [
":keras",
@@ -478,6 +491,19 @@
)
tf_py_test(
+ name = "metrics_correctness_test",
+ size = "medium",
+ srcs = ["metrics_correctness_test.py"],
+ additional_deps = [
+ ":keras",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ ],
+ shard_count = 4,
+)
+
+tf_py_test(
name = "applications_test",
size = "enormous",
srcs = ["applications/applications_test.py"],
@@ -796,6 +822,7 @@
],
shard_count = 4,
tags = [
+ "no_oss", # b/123899138
"noasan", # http://b/78599823
"notsan",
],
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 8f10aca..a10629a 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -26,6 +26,19 @@
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import keras_export
+# b/123041942
+# In TF 2.x, if the `tf.nn.softmax` is used as an activation function in Keras
+# layers, it gets serialized as 'softmax_v2' instead of 'softmax' as the
+# internal method name is returned in serialization. This results in errors in
+# model exporting and loading as Keras can't find any activation function with
+# the name of `softmax_v2`.
+
+# This dict maps the activation function name from its v2 version to its
+# canonical name.
+_TF_ACTIVATIONS_V2 = {
+ 'softmax_v2': 'softmax',
+}
+
@keras_export('keras.activations.softmax')
def softmax(x, axis=-1):
@@ -190,6 +203,8 @@
@keras_export('keras.activations.serialize')
def serialize(activation):
+ if activation.__name__ in _TF_ACTIVATIONS_V2:
+ return _TF_ACTIVATIONS_V2[activation.__name__]
return activation.__name__
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index 33001f4..9d219548 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -22,6 +22,7 @@
from tensorflow.python import keras
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import nn_ops as nn
from tensorflow.python.platform import test
@@ -46,6 +47,14 @@
fn = keras.activations.deserialize(config)
assert fn == ref_fn
+ def test_serialization_v2(self):
+ activation_map = {nn.softmax_v2: 'softmax'}
+ for fn_v2_key in activation_map:
+ fn_v2 = keras.activations.get(fn_v2_key)
+ config = keras.activations.serialize(fn_v2)
+ fn = keras.activations.deserialize(config)
+ assert fn.__name__ == activation_map[fn_v2_key]
+
def test_softmax(self):
x = keras.backend.placeholder(ndim=2)
f = keras.backend.function([x], [keras.activations.softmax(x)])
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 837fd26..064358c 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -32,6 +32,9 @@
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_module
+from tensorflow.python.distribute import distribute_coordinator as dc
+from tensorflow.python.distribute import distribute_coordinator_context as dc_context
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
@@ -61,7 +64,7 @@
from tensorflow.python.ops import tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables as variables_module
-
+from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_inspect
@@ -394,8 +397,14 @@
session = default_session
else:
if getattr(_SESSION, 'session', None) is None:
- _SESSION.session = session_module.Session(
- config=get_default_session_config())
+ # We are creating the Session inside a Distribution
+ # Strategy scope.
+ if distribution_strategy_context.has_strategy():
+ configure_and_create_distributed_session(
+ distribution_strategy_context.get_strategy())
+ else:
+ _SESSION.session = session_module.Session(
+ config=get_default_session_config())
session = _SESSION.session
return session
@@ -2884,7 +2893,7 @@
'should be a list or tuple.')
self.inputs = nest.flatten(inputs)
self._outputs_structure = outputs
- self.outputs = nest.flatten(outputs)
+ self.outputs = cast_variables_to_tensor(nest.flatten(outputs))
with ops.control_dependencies(self.outputs):
updates_ops = []
for update in updates:
@@ -3044,14 +3053,13 @@
if not isinstance(updates, (list, tuple)):
raise TypeError('`updates` in a Keras backend function '
'should be a list or tuple.')
+ self.name = name
self.inputs = nest.flatten(inputs)
self._outputs_structure = outputs
- self.outputs = nest.flatten(outputs)
- self.name = name
-
graph = get_graph()
# Consolidate updates
with graph.as_default():
+ self.outputs = cast_variables_to_tensor(nest.flatten(outputs))
with ops.control_dependencies(self.outputs):
# In general, updates should be run after the outputs have been
# computed. However, we can only ensure this when we create
@@ -5201,3 +5209,65 @@
except IOError:
# Except permission denied.
pass
+
+
+def in_multi_worker_mode():
+ """Whether we are operating in a Multi-Worker setting."""
+ tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
+ cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
+ return tf_config and 'master' not in cluster_spec.jobs
+
+
+def configure_and_create_distributed_session(distribution_strategy):
+ """Configure session config and create a session with it."""
+
+ # TODO(priyag): Throw error if a session already exists.
+ def _create_session(distribution_strategy):
+ """Create the Distributed Strategy session."""
+ session_config = get_default_session_config()
+
+ if is_tpu_strategy(distribution_strategy):
+ # TODO(priyag, yuefengz): Remove this workaround when Distribute
+ # Coordinator is integrated with keras and we can create a session from
+ # there.
+ distribution_strategy.configure(session_config)
+ master = distribution_strategy.extended._tpu_cluster_resolver.master() # pylint: disable=protected-access
+ session = session_module.Session(config=session_config, target=master)
+ else:
+ worker_context = dc_context.get_current_worker_context()
+ if worker_context:
+ dc_session_config = worker_context.session_config
+ # Merge the default session config to the one from distribute
+ # coordinator, which is fine for now since they don't have
+ # conflicting configurations.
+ dc_session_config.MergeFrom(session_config)
+ session = session_module.Session(
+ config=dc_session_config, target=worker_context.master_target)
+ else:
+ distribution_strategy.configure(session_config)
+ session = session_module.Session(config=session_config)
+
+ set_session(session)
+
+ if in_multi_worker_mode():
+ dc.run_distribute_coordinator(
+ _create_session,
+ distribution_strategy,
+ mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
+ else:
+ _create_session(distribution_strategy)
+
+
+def is_tpu_strategy(strategy):
+ """We're executing TPU Strategy."""
+ return strategy is not None and strategy.__class__.__name__ == 'TPUStrategy'
+
+
+def cast_variables_to_tensor(tensors):
+
+ def _cast_variables_to_tensor(tensor):
+ if isinstance(tensor, variables_module.Variable):
+ return array_ops.identity(tensor)
+ return tensor
+
+ return nest.map_structure(_cast_variables_to_tensor, tensors)
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 3223c89..41fedbb 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -1118,9 +1118,9 @@
ValueError: If histogram_freq is set and no validation data is provided.
@compatibility(eager)
- Using `Tensorboard` callback will work while eager execution is enabled,
- however outputting histogram summaries of weights and gradients is not
- supported, and thus `histogram_freq` will be ignored.
+ Using the `TensorBoard` callback will work when eager execution is enabled,
+ with the restriction that outputting histogram summaries of weights and
+ gradients is not supported. Consequently, `histogram_freq` will be ignored.
@end_compatibility
"""
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 12d3f09..eee23b1 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -244,8 +244,8 @@
@doc_controls.for_subclass_implementers
def add_weight(self,
- name,
- shape,
+ name=None,
+ shape=None,
dtype=None,
initializer=None,
regularizer=None,
@@ -259,8 +259,8 @@
"""Adds a new variable to the layer.
Arguments:
- name: variable name.
- shape: variable shape.
+ name: Variable name.
+ shape: Variable shape. Defaults to scalar if unspecified.
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
initializer: initializer instance (callable).
regularizer: regularizer instance (callable).
@@ -297,6 +297,7 @@
ValueError: When giving unsupported dtype and no initializer or when
trainable has been set to True with synchronization set as `ON_READ`.
"""
+ shape = shape or ()
# Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in ['getter', 'collections']:
@@ -363,8 +364,10 @@
# TODO(fchollet): in the future, this should be handled at the
# level of variable creation, and weight regularization losses
# should be variable attributes.
- self._handle_weight_regularization(name, variable, regularizer)
-
+ name_in_scope = variable.name[:variable.name.find(':')]
+ self._handle_weight_regularization(name_in_scope,
+ variable,
+ regularizer)
if trainable:
self._trainable_weights.append(variable)
else:
@@ -509,8 +512,15 @@
"""
input_list = nest.flatten(inputs)
# Accept NumPy inputs by converting to Tensors.
- if all(isinstance(x, (np.ndarray, float, int)) for x in input_list):
- inputs = nest.map_structure(ops.convert_to_tensor, inputs)
+ if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
+ # Don't call `ops.convert_to_tensor` on all `inputs` because
+ # `SparseTensors` can't be converted to `Tensor`.
+ def _convert_non_tensor(x):
+ if isinstance(x, (np.ndarray, float, int)):
+ return ops.convert_to_tensor(x)
+ return x
+
+ inputs = nest.map_structure(_convert_non_tensor, inputs)
input_list = nest.flatten(inputs)
# We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -541,64 +551,59 @@
# pass to __call__, hence we set previous_mask as the default value.
kwargs['mask'] = previous_mask
- with ops.name_scope(self._name_scope()):
- if not self.built:
+ # Check input assumptions set after layer building, e.g. input shape.
+ if build_graph:
+ # Symbolic execution on symbolic tensors. We will attempt to build
+ # the corresponding TF subgraph inside `backend.get_graph()`
+ input_spec.assert_input_compatibility(self.input_spec, inputs, self.name)
+ graph = backend.get_graph()
+ with graph.as_default(), ops.name_scope(self._name_scope()):
# Build layer if applicable (if the `build` method has been overridden).
self._maybe_build(inputs)
- # We must set self.built since user defined build functions are not
- # constrained to set self.built.
- self.built = True
+ if not self.dynamic:
+ try:
+ outputs = self.call(inputs, *args, **kwargs)
+ except TypeError as e:
+ messages = ('`tf.Tensor` as a Python `bool` is not allowed',
+ 'Tensor objects are only iterable when eager')
+ exception_str = str(e)
+ for msg in messages:
+ if msg in exception_str:
+ raise TypeError('You are attempting to use Python control '
+ 'flow in a layer that was not declared to be '
+ 'dynamic. Pass `dynamic=True` to the class '
+ 'constructor.\nEncountered error:\n"""\n' +
+ exception_str + '\n"""')
+ raise
+ else:
+ # We will use static shape inference to return symbolic tensors
+ # matching the specifications of the layer outputs.
+ # Since `self.dynamic` is True, we will never attempt to
+ # run the underlying TF graph (which is disconnected).
+ # TODO(fchollet): consider py_func as an alternative, which
+ # would enable us to run the underlying graph if needed.
+ outputs = self._symbolic_call(inputs)
- # Check input assumptions set after layer building, e.g. input shape.
- if build_graph:
- # Symbolic execution on symbolic tensors. We will attempt to build
- # the corresponding TF subgraph inside `backend.get_graph()`
- input_spec.assert_input_compatibility(
- self.input_spec, inputs, self.name)
- graph = backend.get_graph()
- with graph.as_default():
- if not self.dynamic:
- try:
- outputs = self.call(inputs, *args, **kwargs)
- except TypeError as e:
- messages = ('`tf.Tensor` as a Python `bool` is not allowed',
- 'Tensor objects are only iterable when eager')
- exception_str = str(e)
- for msg in messages:
- if msg in exception_str:
- raise TypeError('You are attempting to use Python control '
- 'flow in a layer that was not declared to be '
- 'dynamic. Pass `dynamic=True` to the class '
- 'constructor.\nEncountered error:\n"""\n' +
- exception_str + '\n"""')
- raise
- else:
- # We will use static shape inference to return symbolic tensors
- # matching the specifications of the layer outputs.
- # Since `self.dynamic` is True, we will never attempt to
- # run the underlying TF graph (which is disconnected).
- # TODO(fchollet): consider py_func as an alternative, which
- # would enable us to run the underlying graph if needed.
- outputs = self._symbolic_call(inputs)
-
- if outputs is None:
- raise ValueError('A layer\'s `call` method should return a '
- 'Tensor or a list of Tensors, not None '
- '(layer: ' + self.name + ').')
- if base_layer_utils.have_all_keras_metadata(inputs):
- inputs, outputs = self._set_connectivity_metadata_(
- inputs, outputs, args, kwargs)
- self._handle_activity_regularization(inputs, outputs)
- self._set_mask_metadata(inputs, outputs, previous_mask)
- if hasattr(self, '_set_inputs') and not self.inputs:
- # Subclassed network: explicitly set metadata normally set by
- # a call to self._set_inputs().
- # TODO(b/120997007): This should be done in Eager as well, but
- # causes garbage collection issues because of the placeholders
- # created on the default Keras graph.
- self._set_inputs(inputs, outputs)
- else:
- # Eager execution on data tensors.
+ if outputs is None:
+ raise ValueError('A layer\'s `call` method should return a '
+ 'Tensor or a list of Tensors, not None '
+ '(layer: ' + self.name + ').')
+ if base_layer_utils.have_all_keras_metadata(inputs):
+ inputs, outputs = self._set_connectivity_metadata_(
+ inputs, outputs, args, kwargs)
+ self._handle_activity_regularization(inputs, outputs)
+ self._set_mask_metadata(inputs, outputs, previous_mask)
+ if hasattr(self, '_set_inputs') and not self.inputs:
+ # Subclassed network: explicitly set metadata normally set by
+ # a call to self._set_inputs().
+ # TODO(b/120997007): This should be done in Eager as well, but
+ # causes garbage collection issues because of the placeholders
+ # created on the default Keras graph.
+ self._set_inputs(inputs, outputs)
+ else:
+ # Eager execution on data tensors.
+ with ops.name_scope(self._name_scope()):
+ self._maybe_build(inputs)
outputs = self.call(inputs, *args, **kwargs)
self._handle_activity_regularization(inputs, outputs)
self._set_mask_metadata(inputs, outputs, previous_mask)
@@ -1348,32 +1353,35 @@
self.add_loss(mean_activity_loss, inputs=inputs)
def _set_mask_metadata(self, inputs, outputs, previous_mask):
- if getattr(self, '_compute_output_and_mask_jointly', False):
- # Mask is already computed for Keras Graph Networks.
- return
-
flat_outputs = nest.flatten(outputs)
- if all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs):
- # Mask is already computed by sublayers.
- return
+ mask_already_computed = (
+ getattr(self, '_compute_output_and_mask_jointly', False) or
+ all(getattr(x, '_keras_mask', None) is not None for x in flat_outputs))
- if hasattr(self, 'compute_mask'):
- output_masks = self.compute_mask(inputs, previous_mask)
- # `compute_mask` can return a single `None` even when a Layer
- # has multiple outputs.
- if output_masks is None:
- flat_masks = [None for _ in flat_outputs]
+ if not mask_already_computed:
+ if hasattr(self, 'compute_mask'):
+ output_masks = self.compute_mask(inputs, previous_mask)
+ # `compute_mask` can return a single `None` even when a Layer
+ # has multiple outputs.
+ if output_masks is None:
+ flat_masks = [None for _ in flat_outputs]
+ else:
+ flat_masks = nest.flatten(output_masks)
else:
- flat_masks = nest.flatten(output_masks)
- else:
- flat_masks = [None for _ in flat_outputs]
+ flat_masks = [None for _ in flat_outputs]
- for output, mask in zip(flat_outputs, flat_masks):
- try:
- output._keras_mask = mask
- except AttributeError:
- # C Type such as np.ndarray.
- pass
+ for output, mask in zip(flat_outputs, flat_masks):
+ try:
+ output._keras_mask = mask
+ except AttributeError:
+ # C Type such as np.ndarray.
+ pass
+
+ if tf_utils.are_all_symbolic_tensors(flat_outputs):
+ for output in flat_outputs:
+ if getattr(output, '_keras_mask', None) is not None:
+ # Do not track masks for `TensorFlowOpLayer` construction.
+ output._keras_mask._keras_history_checked = True
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
call_convention = getattr(
@@ -1568,6 +1576,9 @@
def _maybe_build(self, inputs):
# Check input assumptions set before layer building, e.g. input rank.
+ if self.built:
+ return
+
input_spec.assert_input_compatibility(
self.input_spec, inputs, self.name)
input_list = nest.flatten(inputs)
@@ -1582,6 +1593,9 @@
# Only call `build` if the user has manually overridden the build method.
if not hasattr(self.build, '_is_default'):
self.build(input_shapes)
+ # We must set self.built since user defined build functions are not
+ # constrained to set self.built.
+ self.built = True
def _symbolic_call(self, inputs):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
diff --git a/tensorflow/python/keras/engine/base_layer_test.py b/tensorflow/python/keras/engine/base_layer_test.py
index faeb4c5..109fc1f 100644
--- a/tensorflow/python/keras/engine/base_layer_test.py
+++ b/tensorflow/python/keras/engine/base_layer_test.py
@@ -225,6 +225,28 @@
model(np.zeros((2, 4), dtype='float32'))
self.assertTrue(model.built)
+ @test_util.run_in_graph_and_eager_modes
+ def test_default_add_weight(self):
+
+ class TestLayer(keras.layers.Layer):
+
+ def __init__(self):
+ super(TestLayer, self).__init__()
+ self.default_weight = self.add_weight()
+ self.weight_without_name = self.add_weight(shape=(3, 4))
+ self.regularized_weight_without_name = self.add_weight(
+ shape=(3, 4), regularizer='l2')
+
+ layer = TestLayer()
+ self.assertEqual(layer.default_weight.shape.as_list(), [])
+ self.assertEqual(layer.weight_without_name.shape.as_list(), [3, 4])
+ self.assertEqual(layer.default_weight.dtype.name, 'float32')
+ self.assertEqual(layer.weight_without_name.dtype.name, 'float32')
+ self.assertEqual(len(layer.losses), 1)
+ if not context.executing_eagerly():
+ # Cannot access tensor.name in eager execution.
+ self.assertTrue('Variable_2/Regularizer' in layer.losses[0].name)
+
def test_learning_phase_freezing_for_layers(self):
# This test is only meant to run in graph functions mode (ambient eager).
# In forced eager, `model.predict` ignores the global learning phase
@@ -347,6 +369,27 @@
function_name = last_entry[2]
self.assertEqual(function_name, 'easily_identifiable_name')
+ # Cannot be enabled with `run_eagerly=True`, see b/123904578
+ @test_util.run_all_in_graph_and_eager_modes
+ def test_layer_can_return_variable(self):
+
+ class ComputeSum(keras.layers.Layer):
+
+ def __init__(self):
+ super(ComputeSum, self).__init__()
+ self.total = variables.Variable(
+ initial_value=array_ops.zeros((1, 1)), trainable=False)
+ if not context.executing_eagerly():
+ keras.backend.get_session().run(self.total.initializer)
+
+ def call(self, inputs):
+ self.total.assign_add(inputs)
+ return self.total
+
+ inputs = keras.Input(shape=(1,))
+ model = keras.Model(inputs, ComputeSum()(inputs))
+ model.predict(np.ones((1, 1)))
+
@test_util.run_all_in_graph_and_eager_modes
class NestedTrackingTest(test.TestCase):
@@ -434,6 +477,34 @@
self.assertEqual(len(layer.updates), 3)
+@test_util.run_all_in_graph_and_eager_modes
+class NameScopingTest(keras_parameterized.TestCase):
+
+ def test_name_scope_layer(self):
+ x = keras.backend.placeholder(shape=(10, 10))
+ layer = keras.layers.Dense(10, name='MyName')
+ layer(x)
+ self.assertEqual(layer.bias.name, 'MyName/bias:0')
+ self.assertEqual(layer.kernel.name, 'MyName/kernel:0')
+
+ def test_name_scope_sublayer(self):
+ x = keras.backend.placeholder(shape=(10, 10))
+ layer = keras.layers.Dense(
+ 10, activation=keras.layers.ReLU(name='MyAct'), name='MyName2')
+ y = layer(x)
+ self.assertEqual(layer.bias.name, 'MyName2/bias:0')
+ self.assertEqual(layer.kernel.name, 'MyName2/kernel:0')
+ self.assertEqual(y.name, 'MyName2/MyAct/Relu:0')
+
+ def test_name_scope_tf_tensor(self):
+ x = ops.convert_to_tensor(np.ones((10, 10)))
+ layer = keras.layers.Dense(
+ 10, activation=keras.layers.ReLU(name='MyAct'), name='MyName3')
+ layer(x)
+ self.assertEqual(layer.bias.name, 'MyName3/bias:0')
+ self.assertEqual(layer.kernel.name, 'MyName3/kernel:0')
+
+
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 70de386..ccbc1c9 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -20,11 +20,10 @@
import numpy as np
-from tensorflow.python.client import session as session_module
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.distribute import distribute_coordinator_context as dc_context
-from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -105,7 +104,7 @@
grouped_inputs)
if with_loss_tensor:
# reduce loss tensor before adding it to the list of fetches
- loss = distribution_strategy.reduce(distribute_lib.get_loss_reduction(),
+ loss = distribution_strategy.reduce(reduce_util.ReduceOp.SUM,
grouped_outputs[0])
all_outputs = flatten_perdevice_values(distribution_strategy,
grouped_outputs[1:])
@@ -351,34 +350,6 @@
_wait_for_variable_initialization(session)
-def configure_and_create_session(distribution_strategy):
- """Configure session config and create a session with it."""
- # TODO(priyag): Throw error if a session already exists.
- session_config = K.get_default_session_config()
-
- if is_tpu_strategy(distribution_strategy):
- # TODO(priyag, yuefengz): Remove this workaround when Distribute
- # Coordinator is integrated with keras and we can create a session from
- # there.
- distribution_strategy.configure(session_config)
- master = distribution_strategy.extended._tpu_cluster_resolver.master() # pylint: disable=protected-access
- session = session_module.Session(config=session_config, target=master)
- else:
- worker_context = dc_context.get_current_worker_context()
- if worker_context:
- dc_session_config = worker_context.session_config
- # Merge the default session config to the one from distribute coordinator,
- # which is fine for now since they don't have conflicting configurations.
- dc_session_config.MergeFrom(session_config)
- session = session_module.Session(
- config=dc_session_config, target=worker_context.master_target)
- else:
- distribution_strategy.configure(session_config)
- session = session_module.Session(config=session_config)
-
- K.set_session(session)
-
-
def validate_inputs(x, y, distribution_strategy, allow_partial_batch=False):
"""Validate inputs when using DistributionStrategy.
@@ -752,21 +723,46 @@
def _make_execution_function(model, mode):
- """Makes function to run one step of distributed model execution."""
- if context.executing_eagerly():
- return _make_eager_execution_function(model, mode)
-
+ """Makes or reuses function to run one step of distributed model execution."""
strategy = model._distribution_strategy
- if not get_distributed_model(model, mode):
- if model._compile_distribution:
- clone_model_on_replicas(model, strategy, mode)
- else:
- _build_distributed_network(model, strategy, mode)
+
+ distributed_model = get_distributed_model(model, mode)
+ # If distributed model for a particular `mode` is already built, use the
+ # `_distribution_function` on that distributed model.
+ if distributed_model:
+ return distributed_model._distributed_function
+
+ # If distributed_model is not built, create one for `mode`.
+ if model._compile_distribution:
+ clone_model_on_replicas(model, strategy, mode)
+ else:
+ _build_distributed_network(model, strategy, mode)
+
+ # We've just created the distributed model. So `distributed_model` should be
+ # not None.
+ distributed_model = get_distributed_model(model, mode)
+ assert distributed_model
+
+ # Also create an execution fuction on that distributed model.
+ if context.executing_eagerly():
+ distributed_function = _make_eager_execution_function(model, mode)
+ else:
+ distributed_function = _make_graph_execution_function(model, mode)
+
+ # We cache the distributed execution function on the model since creating
+ # distributed models and exection functions are expensive.
+ distributed_model._distributed_function = distributed_function
+ return distributed_function
+
+
+def _make_graph_execution_function(model, mode):
+ """Makes function to run one step of distributed model in graph mode."""
def _per_device_function(model):
f = model._make_execution_function(mode)
return (f.inputs, f.outputs, f.updates_op, f.session_kwargs)
+ strategy = model._distribution_strategy
with strategy.scope():
# Create train ops on each of the devices when we call
# `_per_device_fit_function`.
@@ -802,19 +798,13 @@
def _make_eager_execution_function(model, mode):
"""Makes function to run one step of distributed model eager execution."""
- strategy = model._distribution_strategy
- if not get_distributed_model(model, mode):
- if model._compile_distribution:
- clone_model_on_replicas(model, strategy, mode)
- else:
- _build_distributed_network(model, strategy, mode)
-
def _per_device_function(model):
f = model._make_execution_function(mode)
return (f.inputs, f.outputs)
# NOTE(priyag): Try creating a new FuncGraph within DS scope instead of using
# the global one.
+ strategy = model._distribution_strategy
with K.get_graph().as_default(), strategy.scope():
# Create train ops on each of the devices when we call
# `_per_device_fit_function`.
@@ -880,21 +870,18 @@
def get_distributed_model(model, mode):
- if mode is ModeKeys.TRAIN:
- return model._distributed_model_train
- elif mode is ModeKeys.TEST:
- return model._distributed_model_test
- elif mode is ModeKeys.PREDICT:
- return model._distributed_model_predict
+ key = _generate_cache_key(mode)
+ return model._distributed_model_cache.get(key, None)
def set_distributed_model(model, mode, distributed_model):
- if mode is ModeKeys.TRAIN:
- model._distributed_model_train = distributed_model
- elif mode is ModeKeys.TEST:
- model._distributed_model_test = distributed_model
- elif mode is ModeKeys.PREDICT:
- model._distributed_model_predict = distributed_model
+ key = _generate_cache_key(mode)
+ model._distributed_model_cache[key] = distributed_model
+
+
+def _generate_cache_key(mode):
+ key = hash(mode)
+ return key
@tf_contextlib.contextmanager
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 0835e7c..559df56 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -22,7 +22,6 @@
import copy
import json
import os
-import weakref
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
@@ -207,8 +206,8 @@
self._outbound_nodes = []
self._inbound_nodes = []
- self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
- weakref.ref(self))
+ self._checkpointable_saver = (
+ checkpointable_utils.saver_with_op_caching(self))
@checkpointable.no_automatic_dependency_tracking
def _init_graph_network(self, inputs, outputs, name=None):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 53c6a88..725e20b 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -24,6 +24,7 @@
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.distribute import distribute_coordinator as dc
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -46,6 +47,7 @@
from tensorflow.python.keras.utils.generic_utils import slice_arrays
from tensorflow.python.keras.utils.losses_utils import squeeze_or_expand_dimensions
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import optimizer as tf_optimizer_module
from tensorflow.python.training.checkpointable import base as checkpointable
@@ -272,11 +274,12 @@
self.target_tensors = target_tensors
# Set DistributionStrategy specific parameters.
- for mode in [ModeKeys.TRAIN, ModeKeys.TEST, ModeKeys.PREDICT]:
- distributed_training_utils.set_distributed_model(self, mode, None)
+ self._distributed_model_cache = {}
+
if self._distribution_strategy is not None:
- distributed_training_utils.configure_and_create_session(
- self._distribution_strategy)
+ # Ensures a Session is created and configured correctly for Distribution
+ # Strategy.
+ K.configure_and_create_distributed_session(self._distribution_strategy)
# Initialize model metric attributes.
self._init_metric_attributes()
if not self.built or not self.inputs or not self.outputs:
@@ -311,8 +314,10 @@
' outputs, but you passed loss=' + str(loss))
loss_functions = [training_utils.get_loss_function(l) for l in loss]
else:
- loss_function = training_utils.get_loss_function(loss)
- loss_functions = [loss_function for _ in range(len(self.outputs))]
+ loss_functions = [
+ training_utils.get_loss_function(loss)
+ for _ in range(len(self.outputs))
+ ]
self.loss_functions = loss_functions
skip_target_indices = []
@@ -483,13 +488,15 @@
'_loss'] = output_loss
# Keep track of stateful result tensor and function for the loss.
- mean_wrapped_loss = metrics_module.MeanMetricWrapper(
+ # Reset reduction here as metric wrapper will take care of that.
+ loss_fn.reduction = losses_impl.ReductionV2.NONE
+ output_loss_metric = metrics_module.SumOverBatchSizeMetricWrapper(
loss_fn, name=loss_fn.name)
- result_tensor = self._call_metric_fn(mean_wrapped_loss, y_true,
+ result_tensor = self._call_metric_fn(output_loss_metric, y_true,
y_pred, sample_weight, mask)
self._compile_stateful_metrics_tensors[self.output_names[i] +
'_loss'] = result_tensor
- self._compile_stateful_metric_functions.append(mean_wrapped_loss)
+ self._compile_stateful_metric_functions.append(output_loss_metric)
self._compile_metrics_names.append(self.output_names[i] + '_loss')
if total_loss is None:
@@ -531,6 +538,7 @@
self._fit_function = None
self._eval_function = None
+ self._predict_function = None
self.train_function = None
self.test_function = None
self.predict_function = None
@@ -786,23 +794,52 @@
# Case 1: distribution strategy.
if self._distribution_strategy:
- return training_distributed.fit_distributed(
- self,
- x=x,
- y=y,
- batch_size=batch_size,
- epochs=epochs,
- verbose=verbose,
- callbacks=callbacks,
- validation_split=validation_split,
- validation_data=validation_data,
- shuffle=shuffle,
- class_weight=class_weight,
- sample_weight=sample_weight,
- initial_epoch=initial_epoch,
- steps_per_epoch=steps_per_epoch,
- validation_steps=validation_steps,
- validation_freq=validation_freq)
+ if K.in_multi_worker_mode():
+ # Multi-Worker mode runs the Keras training loop on multiple
+ # servers via the Distribute Coordinator.
+ def _worker_fn(_):
+ """Run training inside the distributed coordinator."""
+ return training_distributed.fit_distributed(
+ self,
+ x=x,
+ y=y,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_split=validation_split,
+ validation_data=validation_data,
+ shuffle=shuffle,
+ class_weight=class_weight,
+ sample_weight=sample_weight,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps,
+ validation_freq=validation_freq)
+
+ # Independent worker only for now.
+ return dc.run_distribute_coordinator(
+ _worker_fn,
+ self._distribution_strategy,
+ mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
+ else:
+ return training_distributed.fit_distributed(
+ self,
+ x=x,
+ y=y,
+ batch_size=batch_size,
+ epochs=epochs,
+ verbose=verbose,
+ callbacks=callbacks,
+ validation_split=validation_split,
+ validation_data=validation_data,
+ shuffle=shuffle,
+ class_weight=class_weight,
+ sample_weight=sample_weight,
+ initial_epoch=initial_epoch,
+ steps_per_epoch=steps_per_epoch,
+ validation_steps=validation_steps,
+ validation_freq=validation_freq)
batch_size = self._validate_or_infer_batch_size(
batch_size, steps_per_epoch, x)
@@ -1017,15 +1054,36 @@
"""
# Case 1: distribution strategy.
if self._distribution_strategy:
- return training_distributed.evaluate_distributed(
- self,
- x=x,
- y=y,
- batch_size=batch_size,
- verbose=verbose,
- sample_weight=sample_weight,
- steps=steps,
- callbacks=callbacks)
+ if K.in_multi_worker_mode():
+ # Multi-Worker mode runs the Keras evaluation loop on multiple
+ # servers via the Distribute Coordinator.
+ def _worker_fn(_):
+ """Run evaluation inside the distributed coordinator."""
+ return training_distributed.evaluate_distributed(
+ self,
+ x=x,
+ y=y,
+ batch_size=batch_size,
+ verbose=verbose,
+ sample_weight=sample_weight,
+ steps=steps,
+ callbacks=callbacks)
+
+ # Independent worker only for now.
+ return dc.run_distribute_coordinator(
+ _worker_fn,
+ self._distribution_strategy,
+ mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
+ else:
+ return training_distributed.evaluate_distributed(
+ self,
+ x=x,
+ y=y,
+ batch_size=batch_size,
+ verbose=verbose,
+ sample_weight=sample_weight,
+ steps=steps,
+ callbacks=callbacks)
batch_size = self._validate_or_infer_batch_size(batch_size, steps, x)
@@ -1727,9 +1785,10 @@
batch_size = 32
return batch_size
- @property
- def _default_save_signature(self):
- return saving_utils.trace_model_call(self)
+ def _list_functions_for_serialization(self):
+ return {
+ '_default_save_signature': saving_utils.trace_model_call(self)
+ }
def _set_sample_weight_attributes(self, sample_weight_mode,
skip_target_weighing_indices):
@@ -1762,7 +1821,7 @@
self._per_output_weighted_metrics = \
training_utils.collect_per_output_metric_info(
weighted_metrics, self.output_names, output_shapes,
- self.loss_functions, self.sample_weights)
+ self.loss_functions, is_weighted=True)
def _add_unique_metric_name(self, metric_name, output_index):
"""Makes the metric name unique and adds it to the model's metric name list.
@@ -2233,18 +2292,14 @@
else:
in_tuple = x
- if shuffle:
- # 1024 is a good buffer size since it is much larger than the average
- # batch size provided by the user and provides sufficient randomness.
- # One thing to keep in mind is the memory usage based on the size of
- # each sample.
- shuffle_buffer = 1024
- else:
- shuffle_buffer = None
ds = strategy.extended.experimental_make_numpy_dataset(in_tuple,
session=session)
- if shuffle_buffer:
- ds = ds.shuffle(shuffle_buffer)
+ if shuffle:
+ # We want a buffer size that is larger than the batch size provided by
+ # the user and provides sufficient randomness. Note that larger
+ # numbers introduce more memory usage based on the size of each
+ # sample.
+ ds = ds.shuffle(max(1024, batch_size * 8))
if repeat:
ds = ds.repeat()
diff --git a/tensorflow/python/keras/engine/training_dataset_test.py b/tensorflow/python/keras/engine/training_dataset_test.py
index 751c265..4d2d68c 100644
--- a/tensorflow/python/keras/engine/training_dataset_test.py
+++ b/tensorflow/python/keras/engine/training_dataset_test.py
@@ -27,6 +27,7 @@
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
+from tensorflow.python.keras import callbacks
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import testing_utils
@@ -34,6 +35,15 @@
from tensorflow.python.platform import tf_logging as logging
+class BatchCounterCallback(callbacks.Callback):
+
+ def __init__(self):
+ self.batch_count = 0
+
+ def on_batch_end(self, *args, **kwargs):
+ self.batch_count += 1
+
+
class TestTrainingWithDatasetIterators(keras_parameterized.TestCase):
@keras_parameterized.run_with_all_model_types
@@ -394,8 +404,11 @@
dataset = dataset_ops.Dataset.from_tensor_slices((inputs, targets))
dataset = dataset.batch(10)
- history = model.fit(dataset, epochs=2, verbose=1)
+ batch_counter = BatchCounterCallback()
+ history = model.fit(dataset, epochs=2, verbose=1, callbacks=[batch_counter])
+
self.assertEqual(len(history.history['loss']), 2)
+ self.assertEqual(batch_counter.batch_count, 20)
model.evaluate(dataset)
out = model.predict(dataset)
self.assertEqual(out.shape[0], 100)
@@ -415,8 +428,11 @@
self.assertEqual(keras.backend.get_value(cardinality.cardinality(dataset)),
cardinality.UNKNOWN)
- history = model.fit(dataset, epochs=2, verbose=1)
+ batch_counter = BatchCounterCallback()
+ history = model.fit(dataset, epochs=2, verbose=1, callbacks=[batch_counter])
+
self.assertEqual(len(history.history['loss']), 2)
+ self.assertEqual(batch_counter.batch_count, 20)
model.evaluate(dataset)
out = model.predict(dataset)
self.assertEqual(out.shape[0], 100)
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 6816058..4a2a337 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -22,7 +22,7 @@
import numpy as np
from tensorflow.python.data.experimental.ops import batching
-from tensorflow.python.distribute import distribute_lib
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.distribute import input_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.framework import constant_op
@@ -73,13 +73,17 @@
batch_size, mode=ModeKeys.TRAIN))
batch_size = model._validate_or_infer_batch_size(
batch_size, steps_per_epoch, x)
+ steps_name = 'steps_per_epoch'
+ if isinstance(x, dataset_ops.DatasetV2):
+ steps_per_epoch = training_utils.infer_steps_for_dataset(
+ x, steps_per_epoch, steps_name=steps_name)
dataset = model._distribution_standardize_user_data(
x, y,
sample_weight=sample_weight,
class_weight=class_weight,
batch_size=batch_size,
check_steps=True,
- steps_name='steps_per_epoch',
+ steps_name=steps_name,
steps=steps_per_epoch,
validation_split=validation_split,
shuffle=shuffle)
@@ -95,13 +99,17 @@
validation_steps, _ = distributed_training_utils.get_input_params(
model._distribution_strategy, first_valx_value, validation_steps,
batch_size)
+ steps_name = 'validation_steps'
+ if isinstance(val_x, dataset_ops.DatasetV2):
+ validation_steps = training_utils.infer_steps_for_dataset(
+ val_x, validation_steps, steps_name=steps_name)
val_dataset = model._distribution_standardize_user_data(
val_x, val_y,
sample_weight=val_sample_weights,
class_weight=None,
batch_size=batch_size,
check_steps=True,
- steps_name='validation_steps',
+ steps_name=steps_name,
steps=validation_steps,
validation_split=validation_split,
shuffle=shuffle)
@@ -152,12 +160,17 @@
steps, batch_size = distributed_training_utils.get_input_params(
model._distribution_strategy, first_x_value, steps, batch_size)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
+ steps_name = 'steps'
+
+ if isinstance(x, dataset_ops.DatasetV2):
+ steps = training_utils.infer_steps_for_dataset(x, steps,
+ steps_name=steps_name)
dataset = model._distribution_standardize_user_data(
x, y,
sample_weight=sample_weight,
batch_size=batch_size,
check_steps=True,
- steps_name='steps',
+ steps_name=steps_name,
steps=steps)
if distributed_training_utils.is_tpu_strategy(model._distribution_strategy):
@@ -188,11 +201,15 @@
model._distribution_strategy, first_x_value, steps,
batch_size, mode=ModeKeys.PREDICT)
batch_size = model._validate_or_infer_batch_size(batch_size, steps, x)
+ steps_name = 'steps'
+ if isinstance(x, dataset_ops.DatasetV2):
+ steps = training_utils.infer_steps_for_dataset(x, steps,
+ steps_name=steps_name)
dataset = model._distribution_standardize_user_data(
x,
batch_size=batch_size,
check_steps=True,
- steps_name='steps',
+ steps_name=steps_name,
steps=steps,
repeat=False,
allow_partial_batch=True)
@@ -250,6 +267,7 @@
Raises:
ValueError: in case of invalid arguments.
"""
+ mode = ModeKeys.TRAIN
# TODO(fchollet): add support for `steps_per_epoch=None` in TPU loops.
current_strategy = model._distribution_strategy
iterator = distributed_training_utils.get_iterator(dataset, current_strategy)
@@ -270,15 +288,16 @@
inputs, targets = inputs
if model._compile_distribution:
distributed_training_utils.clone_model_on_replicas(
- model, current_strategy, ModeKeys.TRAIN, inputs=inputs,
- targets=targets)
+ model, current_strategy, mode, inputs=inputs, targets=targets)
else:
distributed_training_utils._build_distributed_network(
- model, current_strategy, ModeKeys.TRAIN, inputs, targets)
+ model, current_strategy, mode, inputs, targets)
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.extended.call_for_each_replica(
- _per_device_fit_function, args=(model._distributed_model_train,))
+ _per_device_fit_function,
+ args=(distributed_training_utils.get_distributed_model(
+ model, ModeKeys.TRAIN),))
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
current_strategy, grouped_inputs, grouped_outputs,
@@ -292,7 +311,7 @@
for label, output in zip(out_labels, combined_fn.outputs):
if label == 'loss':
- reduce_op = distribute_lib.get_loss_reduction()
+ reduce_op = ds_reduce_util.ReduceOp.SUM
else:
# We reduce all other metrics using mean for now. This is temporary
# workaround until new metrics are in place.
@@ -329,8 +348,7 @@
do_validation = bool(validation_steps)
if model._compile_distribution:
- distributed_training_utils._copy_weights_to_distributed_model(
- model, ModeKeys.TRAIN)
+ distributed_training_utils._copy_weights_to_distributed_model(model, mode)
callbacks = cbks.configure_callbacks(
callbacks,
@@ -340,7 +358,7 @@
steps_per_epoch=steps_per_epoch,
verbose=verbose,
count_mode='steps',
- mode=ModeKeys.TRAIN)
+ mode=mode)
# Calculate the steps each time on the device.
steps_to_run = [current_strategy.extended.steps_per_run] * (
@@ -349,7 +367,7 @@
steps_to_run.append(
steps_per_epoch % current_strategy.extended.steps_per_run)
- callbacks.on_train_begin()
+ callbacks._call_begin_hook(mode)
for epoch in range(initial_epoch, epochs):
distributed_training_utils._reset_metrics(model)
callbacks.on_epoch_begin(epoch)
@@ -358,7 +376,7 @@
prev_step_count = None
for step_count in steps_to_run:
batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
- callbacks.on_batch_begin(step_index, batch_logs)
+ callbacks._call_batch_hook(mode, 'begin', step_index, batch_logs)
if prev_step_count is None or step_count != prev_step_count:
steps_per_run.load(step_count, K.get_session())
prev_step_count = step_count
@@ -373,7 +391,7 @@
break
batch_logs.update(outputs)
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks._call_batch_hook(mode, 'end', step_index, batch_logs)
step_index = step_index + step_count
if callbacks.model.stop_training:
break
@@ -392,7 +410,8 @@
model,
val_dataset,
steps=validation_steps,
- verbose=verbose)
+ verbose=verbose,
+ callbacks=callbacks)
if not isinstance(val_outs, list):
val_outs = [val_outs]
# Same labels assumed.
@@ -402,7 +421,7 @@
callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- callbacks.on_train_end()
+ callbacks._call_end_hook(mode)
if model._compile_distribution:
# Copy the weights back from the replicated model to the original model.
@@ -434,6 +453,7 @@
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the outputs.
"""
+ mode = ModeKeys.TEST
current_strategy = model._distribution_strategy
iterator = distributed_training_utils.get_iterator(dataset, current_strategy)
scope = distributed_training_utils.distributed_scope(
@@ -450,16 +470,17 @@
"""Clones the model and calls make_eval_function."""
inputs, targets = inputs
if model._compile_distribution:
- distributed_training_utils. clone_model_on_replicas(
- model, current_strategy, mode=ModeKeys.TEST, inputs=inputs,
- targets=targets)
+ distributed_training_utils.clone_model_on_replicas(
+ model, current_strategy, mode=mode, inputs=inputs, targets=targets)
else:
distributed_training_utils._build_distributed_network(
- model, current_strategy, ModeKeys.TEST, inputs, targets)
+ model, current_strategy, mode, inputs, targets)
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.extended.call_for_each_replica(
- _per_device_eval_function, args=(model._distributed_model_test,))
+ _per_device_eval_function,
+ args=(distributed_training_utils.get_distributed_model(
+ model, ModeKeys.TEST),))
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
@@ -474,7 +495,7 @@
for label, output in zip(model.metrics_names, combined_fn.outputs):
if label == 'loss':
- reduce_op = distribute_lib.get_loss_reduction()
+ reduce_op = ds_reduce_util.ReduceOp.SUM
else:
# We reduce all other metrics using mean for now. This is temporary
# workaround until new metrics are in place.
@@ -503,8 +524,7 @@
progbar = Progbar(target=steps)
if model._compile_distribution:
- distributed_training_utils._copy_weights_to_distributed_model(
- model, ModeKeys.TEST)
+ distributed_training_utils._copy_weights_to_distributed_model(model, mode)
distributed_training_utils._reset_metrics(model)
@@ -517,10 +537,13 @@
verbose=verbose,
count_mode='steps',
mode=ModeKeys.TEST)
+ callbacks._call_begin_hook(mode)
assert steps is not None
outs = [0.] * len(model.metrics_names)
for step in range(steps):
+ batch_logs = {'batch': step, 'size': 1}
+ callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
_, batch_outs = K.get_session().run([test_op, output_tensors])
for i, label in enumerate(model.metrics_names):
if i == 0:
@@ -530,9 +553,13 @@
# For all stateful metrics, the aggregation is handled by mirrored vars.
outs[i] = batch_outs[label]
+ batch_logs = cbks.make_logs(model, batch_logs, outs, mode)
+ callbacks._call_batch_hook(mode, 'end', step, batch_logs)
if verbose >= 1:
progbar.update(step + 1)
+ callbacks._call_end_hook(mode)
+
scope.__exit__(None, None, None)
if len(outs) >= 0:
outs[0] /= (steps)
@@ -563,6 +590,7 @@
or list of arrays of predictions
(if the model has multiple outputs).
"""
+ mode = ModeKeys.PREDICT
dataset_fully_shaped = (distributed_training_utils.
is_dataset_shape_fully_defined(dataset))
padding_handler = None
@@ -572,9 +600,7 @@
# during graph optimization.
padding_handler = padding_util.PartialBatchPaddingHandler(
model._feed_output_shapes)
- batched_dataset = input_lib._get_batched_dataset(dataset)
- batch_size, _, prefetch_buffer = input_lib._get_batched_dataset_attributes(
- batched_dataset)
+ batch_size, _, prefetch_buffer = input_lib._get_dataset_attributes(dataset)
padding_handler.padded_batch_size = batch_size
padding_handler.padding_mask = dataset.reduce(padding_handler.padding_mask,
padding_handler.update_mask)
@@ -606,15 +632,17 @@
def step_fn(ctx, inputs):
"""Clones the model and calls make_predict_function."""
if model._compile_distribution:
- distributed_training_utils. clone_model_on_replicas(
- model, current_strategy, ModeKeys.PREDICT, inputs=inputs)
+ distributed_training_utils.clone_model_on_replicas(
+ model, current_strategy, mode, inputs=inputs)
else:
distributed_training_utils._build_distributed_network(
- model, current_strategy, ModeKeys.PREDICT, inputs)
+ model, current_strategy, mode, inputs)
(grouped_inputs, grouped_outputs, grouped_updates,
grouped_session_args) = current_strategy.extended.call_for_each_replica(
- _per_device_predict_function, args=(model._distributed_model_predict,))
+ _per_device_predict_function,
+ args=(distributed_training_utils.get_distributed_model(
+ model, ModeKeys.PREDICT),))
(all_inputs, all_outputs, all_updates,
all_session_args) = distributed_training_utils.unwrap_values(
@@ -654,8 +682,7 @@
progbar = Progbar(target=steps)
if model._compile_distribution:
- distributed_training_utils._copy_weights_to_distributed_model(
- model, ModeKeys.PREDICT)
+ distributed_training_utils._copy_weights_to_distributed_model(model, mode)
distributed_training_utils._reset_metrics(model)
@@ -667,7 +694,8 @@
steps_per_epoch=steps,
verbose=verbose,
count_mode='steps',
- mode=ModeKeys.PREDICT)
+ mode=mode)
+ callbacks._call_begin_hook(mode)
assert steps is not None
# Since we do not know how many samples we will see, we cannot pre-allocate
@@ -675,13 +703,19 @@
# and concatenate them upon returning.
unconcatenated_outs = [[] for _ in model.outputs]
for step in range(steps):
+ batch_logs = {'batch': step, 'size': 1}
+ callbacks._call_batch_hook(mode, 'begin', step, batch_logs)
_, batch_outs = K.get_session().run([predict_op, output_tensors])
# TODO(priyag): maybe need to unwrap the outputs first for MirroredStrategy.
for i, label in enumerate(model.output_names):
unconcatenated_outs[i].extend(batch_outs[label])
+ batch_logs = cbks.make_logs(model, batch_logs, batch_outs, mode)
+ callbacks._call_batch_hook(mode, 'end', step, batch_logs)
if verbose >= 1:
progbar.update(step + 1)
+ callbacks._call_end_hook(mode)
+
scope.__exit__(None, None, None)
if len(unconcatenated_outs) == 1:
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index da460ee..7c63161 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -149,12 +149,14 @@
model, mode, class_weight=class_weight)
# Create the queue for the generator.
- output_generator, enqueuer = _make_enqueued_generator(
- generator,
- workers=workers,
- use_multiprocessing=use_multiprocessing,
- max_queue_size=max_queue_size,
- shuffle=shuffle)
+ enqueuer = None
+ if not is_dataset:
+ generator, enqueuer = _make_enqueued_generator(
+ generator,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing,
+ max_queue_size=max_queue_size,
+ shuffle=shuffle)
num_samples_or_steps, use_steps = _get_num_samples_or_steps(
data, steps_per_epoch)
@@ -208,7 +210,7 @@
step = 0
while step < target_steps:
- batch_data = _get_next_batch(output_generator, mode)
+ batch_data = _get_next_batch(generator, mode)
if batch_data is None:
if not is_dataset:
# We ran out of batches while the user passed an iterator (legacy).
@@ -317,10 +319,10 @@
model_iteration, mode=ModeKeys.PREDICT, shuffle=False)
-def _get_next_batch(output_generator, mode):
+def _get_next_batch(generator, mode):
"""Retrieves the next batch of input data."""
try:
- generator_output = next(output_generator)
+ generator_output = next(generator)
except (StopIteration, errors.OutOfRangeError):
return None
if not isinstance(generator_output, tuple):
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 6be4da7..72c4a29 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -2115,69 +2115,6 @@
self.assertEqual(reference_metric_names, model.metrics_names)
@keras_parameterized.run_all_keras_modes
- def test_metrics_correctness(self):
- model = keras.Sequential()
- model.add(
- keras.layers.Dense(
- 3, activation='relu', input_dim=4, kernel_initializer='ones'))
- model.add(
- keras.layers.Dense(
- 1, activation='sigmoid', kernel_initializer='ones'))
- model.compile(
- loss='mae',
- metrics=['accuracy', metrics_module.BinaryAccuracy()],
- optimizer=RMSPropOptimizer(learning_rate=0.001),
- run_eagerly=testing_utils.should_run_eagerly())
-
- # verify correctness of stateful and stateless metrics.
- x = np.ones((100, 4))
- y = np.ones((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 1.)
- self.assertEqual(outs[2], 1.)
-
- y = np.zeros((100, 1))
- outs = model.evaluate(x, y)
- self.assertEqual(outs[1], 0.)
- self.assertEqual(outs[2], 0.)
-
- @keras_parameterized.run_all_keras_modes
- def test_metrics_correctness_with_weighted_metrics(self):
- np.random.seed(1337)
- x = np.array([[[1.], [1.]], [[0.], [0.]]])
- model = keras.models.Sequential()
- model.add(
- keras.layers.TimeDistributed(
- keras.layers.Dense(1, kernel_initializer='ones'),
- input_shape=(2, 1)))
- model.compile(
- RMSPropOptimizer(learning_rate=0.001),
- loss='mse',
- sample_weight_mode='temporal',
- weighted_metrics=['accuracy', 'mse'],
- run_eagerly=testing_utils.should_run_eagerly())
- y = np.array([[[1.], [1.]], [[1.], [1.]]])
-
- outs = model.evaluate(x, y)
- self.assertEqual(outs, [0.5, 0.5, 0.5])
-
- w = np.array([[0., 0.], [0., 0.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertEqual(outs, [0., 0., 0.])
-
- w = np.array([[3., 4.], [1., 2.]])
- outs = model.evaluate(x, y, sample_weight=w)
- self.assertArrayNear(outs, [0.75, 0.7, 0.3], .001)
-
- # Verify that metric value is same with arbitrary weights and batch size.
- x = np.random.random((50, 2, 1))
- y = np.random.random((50, 2, 1))
- w = np.random.random((50, 2))
- mse1 = model.evaluate(x, y, sample_weight=w, batch_size=5)[2]
- mse2 = model.evaluate(x, y, sample_weight=w, batch_size=10)[2]
- self.assertNear(mse1, mse2, err=1e-7)
-
- @keras_parameterized.run_all_keras_modes
def test_metric_state_reset_between_fit_and_evaluate(self):
model = keras.Sequential()
model.add(keras.layers.Dense(3, activation='relu', input_dim=4))
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 4c913b5..22f976a 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -21,7 +21,6 @@
import abc
import collections
from collections import OrderedDict
-import copy
import numpy as np
import six
@@ -521,7 +520,7 @@
output_names,
output_shapes,
loss_fns,
- sample_weights=None):
+ is_weighted=False):
"""Maps metric names and functions to model outputs.
Arguments:
@@ -529,7 +528,7 @@
output_names: a list of the names (strings) of model outputs.
output_shapes: a list of the shapes (strings) of model outputs.
loss_fns: a list of the loss functions corresponding to the model outputs.
- sample_weights: a list of weights to be applied on the model outputs.
+ is_weighted: Boolean indicating whether the given metrics are weighted.
Returns:
A list (one entry per model output) of dicts.
@@ -553,7 +552,12 @@
return [{} for _ in output_names]
if isinstance(metrics, list):
# we then apply all metrics to all outputs.
- nested_metrics = [copy.copy(metrics) for _ in output_names]
+ if len(output_names) > 1:
+ nested_metrics = []
+ for _ in output_names:
+ nested_metrics.append([metrics_module.clone_metric(m) for m in metrics])
+ else:
+ nested_metrics = [metrics]
elif isinstance(metrics, dict):
nested_metrics = []
for name in output_names:
@@ -569,9 +573,7 @@
for i, metrics in enumerate(nested_metrics):
metrics_dict = OrderedDict()
for metric in metrics:
- weighted = False if (sample_weights is None) else (
- sample_weights[i] is not None)
- metric_name = get_metric_name(metric, weighted)
+ metric_name = get_metric_name(metric, is_weighted)
metric_fn = get_metric_function(
metric, output_shape=output_shapes[i], loss_fn=loss_fns[i])
@@ -1342,7 +1344,7 @@
if size == cardinality.INFINITE and steps is None:
raise ValueError('When passing an infinitely repeating dataset, you '
'must specify the `%s` argument.' % (steps_name,))
- if size != cardinality.UNKNOWN:
+ if size >= 0:
if steps is not None and steps * epochs > size:
if epochs > 1:
raise ValueError('The dataset you passed contains %s batches, but you '
diff --git a/tensorflow/python/keras/integration_test.py b/tensorflow/python/keras/integration_test.py
index dc8d1de..03d7e89 100644
--- a/tensorflow/python/keras/integration_test.py
+++ b/tensorflow/python/keras/integration_test.py
@@ -18,12 +18,15 @@
from __future__ import division
from __future__ import print_function
+import os
+
import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import nn_ops as nn
from tensorflow.python.ops import rnn_cell
from tensorflow.python.platform import test
@@ -197,5 +200,47 @@
self.assertEqual(predictions.shape, (x_train.shape[0], 2))
+@keras_parameterized.run_all_keras_modes
+class ActivationV2IntegrationTest(keras_parameterized.TestCase):
+ """Tests activation function V2 in model exporting and loading.
+
+ This test is to verify in TF 2.x, when 'tf.nn.softmax' is used as an
+ activition function, its model exporting and loading work as expected.
+ Check b/123041942 for details.
+ """
+
+ def test_serialization_v2_model(self):
+ np.random.seed(1337)
+ (x_train, y_train), _ = testing_utils.get_test_data(
+ train_samples=100,
+ test_samples=0,
+ input_shape=(10,),
+ num_classes=2)
+ y_train = keras.utils.to_categorical(y_train)
+
+ model = keras.Sequential([
+ keras.layers.Flatten(input_shape=x_train.shape[1:]),
+ keras.layers.Dense(10, activation=nn.relu),
+ # To mimic 'tf.nn.softmax' used in TF 2.x.
+ keras.layers.Dense(y_train.shape[-1], activation=nn.softmax_v2),
+ ])
+
+ # Check if 'softmax' is in model.get_config().
+ last_layer_activation = model.get_layer(index=2).get_config()['activation']
+ self.assertEqual(last_layer_activation, 'softmax')
+
+ model.compile(loss='categorical_crossentropy',
+ optimizer=keras.optimizer_v2.adam.Adam(0.005),
+ metrics=['accuracy'],
+ run_eagerly=testing_utils.should_run_eagerly())
+ model.fit(x_train, y_train, epochs=2, batch_size=10,
+ validation_data=(x_train, y_train),
+ verbose=2)
+
+ output_path = keras.saving.saved_model.export(
+ model, os.path.join(self.get_temp_dir(), 'tf_keras_saved_model'))
+ loaded_model = keras.saving.saved_model.load_from_saved_model(output_path)
+ self.assertEqual(model.summary(), loaded_model.summary())
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 221a9ba..b33c328 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -143,9 +143,12 @@
training = K.learning_phase()
def dropped_inputs():
- return nn.dropout(inputs, 1 - self.rate,
- noise_shape=self._get_noise_shape(inputs),
- seed=self.seed)
+ return nn.dropout(
+ inputs,
+ noise_shape=self._get_noise_shape(inputs),
+ seed=self.seed,
+ rate=self.rate)
+
output = tf_utils.smart_cond(training,
dropped_inputs,
lambda: array_ops.identity(inputs))
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index 9404543..79c6197 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -2152,16 +2152,11 @@
count=3)
inputs *= self._dropout_mask[0]
- experimental_api_name = 'gru_' + str(uuid.uuid4())
- defun_standard_gru = _generate_defun_backend(
- experimental_api_name, _CPU_DEVICE_NAME, standard_gru)
- defun_cudnn_gru = _generate_defun_backend(
- experimental_api_name, _GPU_DEVICE_NAME, cudnn_gru)
if ops.executing_eagerly_outside_functions():
# Under eager context, the device placement is already known. Prefer the
# GPU implementation when GPU is available.
if context.num_gpus() > 0:
- last_output, outputs, new_h, runtime = defun_cudnn_gru(
+ last_output, outputs, new_h, runtime = cudnn_gru(
inputs=inputs,
init_h=initial_state[0],
kernel=self.cell.kernel,
@@ -2169,7 +2164,7 @@
bias=self.cell.bias,
time_major=self.time_major)
else:
- last_output, outputs, new_h, runtime = defun_standard_gru(
+ last_output, outputs, new_h, runtime = standard_gru(
inputs=inputs,
init_h=initial_state[0],
kernel=self.cell.kernel,
@@ -2179,6 +2174,11 @@
recurrent_activation=self.recurrent_activation,
time_major=self.time_major)
else:
+ experimental_api_name = 'gru_' + str(uuid.uuid4())
+ defun_standard_gru = _generate_defun_backend(
+ experimental_api_name, _CPU_DEVICE_NAME, standard_gru)
+ defun_cudnn_gru = _generate_defun_backend(
+ experimental_api_name, _GPU_DEVICE_NAME, cudnn_gru)
# Call the normal GRU impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, runtime = defun_standard_gru(
@@ -3112,29 +3112,29 @@
inputs *= self._dropout_mask[0]
- # Each time a defun function is called, we will give a unique identifiable
- # API name, so that the grappler won't get confused when it sees multiple
- # LSTM layer added into same graph, and it will be able to pair up the
- # different implementations across them.
- experimental_api_name = 'lstm_' + str(uuid.uuid4())
- defun_standard_lstm = _generate_defun_backend(
- experimental_api_name, _CPU_DEVICE_NAME, standard_lstm)
- defun_cudnn_lstm = _generate_defun_backend(
- experimental_api_name, _GPU_DEVICE_NAME, cudnn_lstm)
-
if ops.executing_eagerly_outside_functions():
# Under eager context, the device placement is already known. Prefer the
# GPU implementation here.
if context.num_gpus() > 0:
- last_output, outputs, new_h, new_c, runtime = defun_cudnn_lstm(
+ last_output, outputs, new_h, new_c, runtime = cudnn_lstm(
inputs, initial_state[0], initial_state[1], self.cell.kernel,
self.cell.recurrent_kernel, self.cell.bias, self.time_major)
else:
- last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
+ last_output, outputs, new_h, new_c, runtime = standard_lstm(
inputs, initial_state[0], initial_state[1], self.cell.kernel,
self.cell.recurrent_kernel, self.cell.bias, self.activation,
self.recurrent_activation, self.time_major)
else:
+ # Each time a `tf.function` is called, we will give it a unique
+ # identifiable API name, so that Grappler won't get confused when it
+ # sees multiple LSTM layers added into same graph, and it will be able
+ # to pair up the different implementations across them.
+ experimental_api_name = 'lstm_' + str(uuid.uuid4())
+ defun_standard_lstm = _generate_defun_backend(
+ experimental_api_name, _CPU_DEVICE_NAME, standard_lstm)
+ defun_cudnn_lstm = _generate_defun_backend(
+ experimental_api_name, _GPU_DEVICE_NAME, cudnn_lstm)
+
# Call the normal LSTM impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
diff --git a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
index b83508e..8941d86 100644
--- a/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
+++ b/tensorflow/python/keras/layers/tensorflow_op_layer_test.py
@@ -158,12 +158,15 @@
size_50 = _construct_graph_of_size(50)
size_500 = _construct_graph_of_size(500)
- # Check reasonable graph construction time.
- self.assertLess(size_50, 5)
# Check construction time grows approx. linearly with size.
- e = 1.5 # Fudge factor to prevent flakiness.
+ e = 2 # Fudge factor to prevent flakiness.
self.assertLess(size_500, (10 * e) * size_50)
+ def test_no_mask_tracking(self):
+ x = keras.backend.placeholder((10, 10))
+ y = keras.layers.Masking(0.)(x)
+ self.assertTrue(y._keras_mask._keras_history_checked)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/unified_gru_test.py b/tensorflow/python/keras/layers/unified_gru_test.py
index 8259643..b25007e 100644
--- a/tensorflow/python/keras/layers/unified_gru_test.py
+++ b/tensorflow/python/keras/layers/unified_gru_test.py
@@ -464,6 +464,29 @@
np.testing.assert_allclose(out7, out6, atol=1e-5)
+ def test_stateful_GRU_training(self):
+ # See b/123587692 for more context.
+ vocab_size = 20
+ embedding_dim = 10
+ batch_size = 8
+ timestep = 12
+ units = 5
+ x = np.random.randint(0, vocab_size, size=(batch_size, timestep))
+ y = np.random.randint(0, vocab_size, size=(batch_size, timestep))
+
+ model = keras.Sequential([
+ keras.layers.Embedding(vocab_size, embedding_dim,
+ batch_input_shape=[batch_size, timestep]),
+ keras.layers.UnifiedGRU(units,
+ return_sequences=True,
+ stateful=True),
+ keras.layers.Dense(vocab_size)
+ ])
+ model.compile(optimizer='adam',
+ loss='sparse_categorical_crossentropy',
+ run_eagerly=testing_utils.should_run_eagerly())
+ model.fit(x, y, epochs=1, shuffle=False)
+
class GRULayerGradientTapeTest(test.TestCase):
diff --git a/tensorflow/python/keras/layers/unified_lstm_test.py b/tensorflow/python/keras/layers/unified_lstm_test.py
index 375894b..08153db 100644
--- a/tensorflow/python/keras/layers/unified_lstm_test.py
+++ b/tensorflow/python/keras/layers/unified_lstm_test.py
@@ -633,6 +633,29 @@
self.assertAllClose(out7, out6, atol=1e-5)
+ def test_stateful_LSTM_training(self):
+ # See b/123587692 for more context.
+ vocab_size = 20
+ embedding_dim = 10
+ batch_size = 8
+ timestep = 12
+ units = 5
+ x = np.random.randint(0, vocab_size, size=(batch_size, timestep))
+ y = np.random.randint(0, vocab_size, size=(batch_size, timestep))
+
+ model = keras.Sequential([
+ keras.layers.Embedding(vocab_size, embedding_dim,
+ batch_input_shape=[batch_size, timestep]),
+ keras.layers.UnifiedLSTM(units,
+ return_sequences=True,
+ stateful=True),
+ keras.layers.Dense(vocab_size)
+ ])
+ model.compile(optimizer='adam',
+ loss='sparse_categorical_crossentropy',
+ run_eagerly=testing_utils.should_run_eagerly())
+ model.fit(x, y, epochs=1, shuffle=False)
+
class LSTMLayerGraphOnlyTest(test.TestCase):
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index c9424c9..a10db50 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -206,8 +206,12 @@
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- assert len(input_shape) >= 3
- self.input_spec = InputSpec(shape=input_shape)
+ if len(input_shape) < 3:
+ raise ValueError(
+ '`TimeDistributed` Layer should be passed an `input_shape ` '
+ 'with at least 3 dimensions, received: ' + str(input_shape))
+ # Don't enforce the batch or time dimension.
+ self.input_spec = InputSpec(shape=[None, None] + input_shape[2:])
child_input_shape = [input_shape[0]] + input_shape[2:]
if not self.layer.built:
# The base layer class calls a conversion function on the input shape to
diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py
index f3aa5c4..8a0b265 100644
--- a/tensorflow/python/keras/layers/wrappers_test.py
+++ b/tensorflow/python/keras/layers/wrappers_test.py
@@ -256,6 +256,28 @@
self.assertEqual((mask_outputs_val[1]).all(),
model_input.all())
+ def test_TimeDistributed_with_different_time_shapes(self):
+ time_dist = keras.layers.TimeDistributed(keras.layers.Dense(5))
+ ph_1 = keras.backend.placeholder(shape=(None, 10, 13))
+ out_1 = time_dist(ph_1)
+ self.assertEqual(out_1.shape.as_list(), [None, 10, 5])
+
+ ph_2 = keras.backend.placeholder(shape=(None, 1, 13))
+ out_2 = time_dist(ph_2)
+ self.assertEqual(out_2.shape.as_list(), [None, 1, 5])
+
+ ph_3 = keras.backend.placeholder(shape=(None, 1, 18))
+ with self.assertRaisesRegexp(ValueError, 'is incompatible with layer'):
+ time_dist(ph_3)
+
+ def test_TimeDistributed_with_invalid_dimensions(self):
+ time_dist = keras.layers.TimeDistributed(keras.layers.Dense(5))
+ ph = keras.backend.placeholder(shape=(None, 10))
+ with self.assertRaisesRegexp(
+ ValueError,
+ '`TimeDistributed` Layer should be passed an `input_shape `'):
+ time_dist(ph)
+
class BidirectionalTest(test.TestCase):
diff --git a/tensorflow/python/keras/losses.py b/tensorflow/python/keras/losses.py
index e83b805..366469d 100644
--- a/tensorflow/python/keras/losses.py
+++ b/tensorflow/python/keras/losses.py
@@ -23,6 +23,7 @@
import six
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.keras import backend as K
@@ -31,6 +32,7 @@
from tensorflow.python.keras.utils.losses_utils import compute_weighted_loss
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_variable
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops.losses import losses_impl
@@ -445,12 +447,18 @@
class Hinge(LossFunctionWrapper):
"""Computes the hinge loss between `y_true` and `y_pred`.
+ `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
+ provided we will convert them to -1 or 1.
+
Usage:
```python
h = tf.losses.Hinge()
- loss = h([0., 1., 1.], [1., 0., 1.])
- print('Loss: ', loss.numpy()) # Loss: 0.66
+ loss = h([-1., 1., 1.], [0.6, -0.7, -0.5])
+
+ # loss = max(0, 1 - y_true * y_pred) = [1.6 + 1.7 + 1.5] / 3
+
+ print('Loss: ', loss.numpy()) # Loss: 1.6
```
Usage with tf.keras API:
@@ -471,12 +479,18 @@
class SquaredHinge(LossFunctionWrapper):
"""Computes the squared hinge loss between `y_true` and `y_pred`.
+ `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
+ provided we will convert them to -1 or 1.
+
Usage:
```python
sh = tf.losses.SquaredHinge()
- loss = sh([0., 1., 1.], [1., 0., 1.])
- print('Loss: ', loss.numpy()) # Loss: 0.66
+ loss = sh([-1., 1., 1.], [0.6, -0.7, -0.5])
+
+ # loss = (max(0, 1 - y_true * y_pred))^2 = [1.6^2 + 1.7^2 + 1.5^2] / 3
+
+ print('Loss: ', loss.numpy()) # Loss: 2.566666
```
Usage with tf.keras API:
@@ -729,18 +743,55 @@
return K.mean(math_ops.squared_difference(first_log, second_log), axis=-1)
+def _maybe_convert_labels(y_true):
+ """Converts binary labels into -1/1."""
+ are_zeros = math_ops.equal(y_true, 0)
+ are_ones = math_ops.equal(y_true, 1)
+ is_binary = math_ops.reduce_all(math_ops.logical_or(are_zeros, are_ones))
+
+ def _convert_binary_labels():
+ # Convert the binary labels to -1 or 1.
+ return 2. * y_true - 1.
+
+ updated_y_true = smart_cond.smart_cond(is_binary,
+ _convert_binary_labels, lambda: y_true)
+ return updated_y_true
+
+
@keras_export('keras.metrics.squared_hinge', 'keras.losses.squared_hinge')
def squared_hinge(y_true, y_pred):
+ """Computes the squared hinge loss between `y_true` and `y_pred`.
+
+ Args:
+ y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
+ If binary (0 or 1) labels are provided we will convert them to -1 or 1.
+ y_pred: The predicted values.
+
+ Returns:
+ Tensor with one scalar loss entry per sample.
+ """
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
+ y_true = _maybe_convert_labels(y_true)
return K.mean(
math_ops.square(math_ops.maximum(1. - y_true * y_pred, 0.)), axis=-1)
@keras_export('keras.metrics.hinge', 'keras.losses.hinge')
def hinge(y_true, y_pred):
+ """Computes the hinge loss between `y_true` and `y_pred`.
+
+ Args:
+ y_true: The ground truth values. `y_true` values are expected to be -1 or 1.
+ If binary (0 or 1) labels are provided we will convert them to -1 or 1.
+ y_pred: The predicted values.
+
+ Returns:
+ Tensor with one scalar loss entry per sample.
+ """
y_pred = ops.convert_to_tensor(y_pred)
y_true = math_ops.cast(y_true, y_pred.dtype)
+ y_true = _maybe_convert_labels(y_true)
return K.mean(math_ops.maximum(1. - y_true * y_pred, 0.), axis=-1)
diff --git a/tensorflow/python/keras/losses_test.py b/tensorflow/python/keras/losses_test.py
index 04dd712..e722090 100644
--- a/tensorflow/python/keras/losses_test.py
+++ b/tensorflow/python/keras/losses_test.py
@@ -930,54 +930,93 @@
def test_unweighted(self):
hinge_obj = keras.losses.Hinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # loss = [(0.7 + 0.8 + 0.9 + 0) / 4, (0.75 + 0 + 0.5 + 0.4) / 4]
+ # = [0.6, 0.4125]
+ # reduced loss = (0.6 + 0.4125) / 2
+
loss = hinge_obj(y_true, y_pred)
- self.assertAlmostEqual(self.evaluate(loss), 7.3333, 3)
+ self.assertAllClose(0.506, self.evaluate(loss), atol=1e-3)
def test_scalar_weighted(self):
hinge_obj = keras.losses.Hinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # loss = [(0.7 + 0.8 + 0.9 + 0) / 4, (0.75 + 0 + 0.5 + 0.4) / 4]
+ # = [0.6, 0.4125]
+ # weighted_loss = [0.6 * 2.3, 0.4125 * 2.3]
+ # reduced loss = (0.6 + 0.4125) * 2.3 / 2
+
loss = hinge_obj(y_true, y_pred, sample_weight=2.3)
- self.assertAlmostEqual(self.evaluate(loss), 16.8666, 3)
+ self.assertAlmostEqual(self.evaluate(loss), 1.164, 3)
# Verify we get the same output when the same input is given
loss_2 = hinge_obj(y_true, y_pred, sample_weight=2.3)
- self.assertAlmostEqual(self.evaluate(loss), self.evaluate(loss_2), 3)
+ self.assertAllClose(self.evaluate(loss), self.evaluate(loss_2), 1e-3)
def test_sample_weighted(self):
hinge_obj = keras.losses.Hinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # loss = [(0.7 + 0.8 + 0.9 + 0) / 4, (0.75 + 0 + 0.5 + 0.4) / 4]
+ # = [0.6, 0.4125]
+ # weighted loss = [0.6 * 1.2, 0.4125 * 3.4]
+ # reduced loss = (0.6 * 1.2 + 0.4125 * 3.4) / 2
+
sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAlmostEqual(self.evaluate(loss), 24.9333, 3)
+ self.assertAllClose(self.evaluate(loss), 1.061, 1e-3)
def test_timestep_weighted(self):
hinge_obj = keras.losses.Hinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3, 1),
- dtype=dtypes.float32)
- sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]], shape=(2, 4, 1))
+ y_pred = constant_op.constant(
+ [[-0.3, 0.2, -0.1, 1.6], [-0.25, -1., 0.5, 0.6]], shape=(2, 4, 1))
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2, 1, 3], shape=(2, 4))
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[[-1], [1], [-1], [1]], [[-1], [-1], [1], [1]]]
+ # y_true * y_pred = [[[0.3], [0.2], [0.1], [1.6]],
+ # [[0.25], [1], [0.5], [0.6]]]
+ # 1 - y_true * y_pred = [[[0.7], [0.8], [0.9], [-0.6]],
+ # [[0.75], [0], [0.5], [0.4]]]
+ # loss = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # weighted loss = [[2.1, 4.8, 4.5, 0], [3, 0, 0.5, 1.2]]
+ # reduced loss = (2.1 + 4.8 + 4.5 + 0 + 3 + 0 + 0.5 + 1.2) / 8
+
loss = hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAlmostEqual(self.evaluate(loss), 2.0, 3)
+ self.assertAllClose(self.evaluate(loss), 2.012, 1e-3)
def test_zero_weighted(self):
hinge_obj = keras.losses.Hinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
loss = hinge_obj(y_true, y_pred, sample_weight=0)
- self.assertAlmostEqual(self.evaluate(loss), 0., 3)
+ self.assertAllClose(self.evaluate(loss), 0., 1e-3)
@test_util.run_all_in_graph_and_eager_modes
@@ -991,26 +1030,46 @@
def test_unweighted(self):
sq_hinge_obj = keras.losses.SquaredHinge()
- y_true = constant_op.constant([1, 9, 2, -5], shape=(2, 2))
- y_pred = constant_op.constant([4, 8, 12, 8],
- shape=(2, 2),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
- # Sq hinge = mean(square(max(1. - y_true * y_pred, 0.)), axis=-1)
- # (1. - y_true * y_pred) = [[1-4, 1-72], [1-24, 1+40]] = [0, 48]
- # sq(max(above val, 0)) = sq([[0, 0], [0, 41]) = [[0, 0], [0, 1681]]
- # Mean = [0, 840.5]. Reduced loss = (0 + 840.5)/2 = 420.25
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # max(0, 1 - y_true * y_pred) = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # squared(max(0, 1 - y_true * y_pred)) = [[0.49, 0.64, 0.81, 0],
+ # [0.5625, 0, 0.25, 0.16]]
+ # loss = [(0.49 + 0.64 + 0.81 + 0) / 4, (0.5625 + 0 + 0.25 + 0.16) / 4]
+ # = [0.485, 0.2431]
+ # reduced loss = (0.485 + 0.2431) / 2
+
loss = sq_hinge_obj(y_true, y_pred)
- self.assertAlmostEqual(self.evaluate(loss), 420.25, 3)
+ self.assertAllClose(self.evaluate(loss), 0.364, 1e-3)
def test_scalar_weighted(self):
sq_hinge_obj = keras.losses.SquaredHinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # max(0, 1 - y_true * y_pred) = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # squared(max(0, 1 - y_true * y_pred)) = [[0.49, 0.64, 0.81, 0],
+ # [0.5625, 0, 0.25, 0.16]]
+ # loss = [(0.49 + 0.64 + 0.81 + 0) / 4, (0.5625 + 0 + 0.25 + 0.16) / 4]
+ # = [0.485, 0.2431]
+ # weighted loss = [0.485 * 2.3, 0.2431 * 2.3]
+ # reduced loss = (0.485 + 0.2431) * 2.3 / 2
+
loss = sq_hinge_obj(y_true, y_pred, sample_weight=2.3)
- self.assertAlmostEqual(self.evaluate(loss), 647.833, 3)
+ self.assertAllClose(self.evaluate(loss), 0.837, 1e-3)
# Verify we get the same output when the same input is given
loss_2 = sq_hinge_obj(y_true, y_pred, sample_weight=2.3)
@@ -1018,32 +1077,55 @@
def test_sample_weighted(self):
sq_hinge_obj = keras.losses.SquaredHinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
- sample_weight = constant_op.constant([1.2, 3.4], shape=(2, 1))
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # max(0, 1 - y_true * y_pred) = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # squared(max(0, 1 - y_true * y_pred)) = [[0.49, 0.64, 0.81, 0],
+ # [0.5625, 0, 0.25, 0.16]]
+ # loss = [(0.49 + 0.64 + 0.81 + 0) / 4, (0.5625 + 0 + 0.25 + 0.16) / 4]
+ # = [0.485, 0.2431]
+ # weighted loss = [0.485 * 1.2, 0.2431 * 3.4]
+ # reduced loss = (0.485 * 1.2 + 0.2431 * 3.4) / 2
+
+ sample_weight = constant_op.constant([1.2, 3.4])
loss = sq_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAlmostEqual(self.evaluate(loss), 957.667, 3)
+ self.assertAllClose(self.evaluate(loss), 0.704, 1e-3)
def test_timestep_weighted(self):
sq_hinge_obj = keras.losses.SquaredHinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3, 1))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3, 1),
- dtype=dtypes.float32)
- sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2], shape=(2, 3))
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]], shape=(2, 4, 1))
+ y_pred = constant_op.constant(
+ [[-0.3, 0.2, -0.1, 1.6], [-0.25, -1., 0.5, 0.6]], shape=(2, 4, 1))
+ sample_weight = constant_op.constant([3, 6, 5, 0, 4, 2, 1, 3], shape=(2, 4))
+
+ # loss = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[[-1], [1], [-1], [1]], [[-1], [-1], [1], [1]]]
+ # y_true * y_pred = [[[0.3], [0.2], [0.1], [1.6]],
+ # [[0.25], [1], [0.5], [0.6]]]
+ # 1 - y_true * y_pred = [[[0.7], [0.8], [0.9], [-0.6]],
+ # [[0.75], [0], [0.5], [0.4]]]
+ # loss = [[0.49, 0.64, 0.81, 0], [0.5625, 0, 0.25, 0.16]]
+ # weighted loss = [[1.47, 3.84, 4.05, 0], [2.25, 0, 0.25, 0.48]]
+ # reduced loss = (1.47 + 3.84 + 4.05 + 0 + 2.25 + 0 + 0.25 + 0.48) / 8
+
loss = sq_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAlmostEqual(self.evaluate(loss), 6.0, 3)
+ self.assertAllClose(self.evaluate(loss), 1.542, 1e-3)
def test_zero_weighted(self):
sq_hinge_obj = keras.losses.SquaredHinge()
- y_true = constant_op.constant([1, 9, 2, -5, -2, 6], shape=(2, 3))
- y_pred = constant_op.constant([4, 8, 12, 8, 1, 3],
- shape=(2, 3),
- dtype=dtypes.float32)
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
loss = sq_hinge_obj(y_true, y_pred, sample_weight=0)
- self.assertAlmostEqual(self.evaluate(loss), 0., 3)
+ self.assertAllClose(self.evaluate(loss), 0., 1e-3)
@test_util.run_all_in_graph_and_eager_modes
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index f7ecbef..63ff5e6 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -651,7 +651,8 @@
For example, if `y_true` is [[0, 0, 1], [0, 1, 0]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
- would be .3.
+ would be .3. You can provide logits of classes as `y_pred`, since argmax of
+ logits and probabilities are same.
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
@@ -701,7 +702,8 @@
For example, if `y_true` is [[2], [1]] and `y_pred` is
[[0.1, 0.9, 0.8], [0.05, 0.95, 0]] then the categorical accuracy is 1/2 or .5.
If the weights were specified as [0.7, 0.3] then the categorical accuracy
- would be .3.
+ would be .3. You can provide logits of classes as `y_pred`, since argmax of
+ logits and probabilities are same.
This metric creates two local variables, `total` and `count` that are used to
compute the frequency with which `y_pred` matches `y_true`. This frequency is
@@ -1975,15 +1977,21 @@
class Hinge(MeanMetricWrapper):
"""Computes the hinge metric between `y_true` and `y_pred`.
- For example, if `y_true` is [0., 1., 1.], and `y_pred` is [1., 0., 1.]
- the hinge metric value is 0.66.
+ `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
+ provided we will convert them to -1 or 1.
+
+ For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5]
+ the hinge metric value is 1.6.
Usage:
```python
m = tf.keras.metrics.Hinge()
- m.update_state([0., 1., 1.], [1., 0., 1.])
- print('Final result: ', m.result().numpy()) # Final result: 0.66
+ m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5])
+
+ # result = max(0, 1-y_true * y_pred) = [1.6 + 1.7 + 1.5] / 3
+
+ print('Final result: ', m.result().numpy()) # Final result: 1.6
```
Usage with tf.keras API:
@@ -2002,15 +2010,21 @@
class SquaredHinge(MeanMetricWrapper):
"""Computes the squared hinge metric between `y_true` and `y_pred`.
- For example, if `y_true` is [0., 1., 1.], and `y_pred` is [1., 0., 1.]
- the squared hinge metric value is 0.66.
+ `y_true` values are expected to be -1 or 1. If binary (0 or 1) labels are
+ provided we will convert them to -1 or 1.
+
+ For example, if `y_true` is [-1., 1., 1.], and `y_pred` is [0.6, -0.7, -0.5]
+ the squared hinge metric value is 2.6.
Usage:
```python
m = tf.keras.metrics.SquaredHinge()
- m.update_state([0., 1., 1.], [1., 0., 1.])
- print('Final result: ', m.result().numpy()) # Final result: 0.66
+ m.update_state([-1., 1., 1.], [0.6, -0.7, -0.5])
+
+ # result = max(0, 1-y_true * y_pred) = [1.6^2 + 1.7^2 + 1.5^2] / 3
+
+ print('Final result: ', m.result().numpy()) # Final result: 2.6
```
Usage with tf.keras API:
@@ -2614,6 +2628,63 @@
axis=axis)
+class SumOverBatchSize(Reduce):
+ """Computes the weighted sum over batch size of the given values.
+
+ For example, if values is [1, 3, 5, 7] then the metric value is 4.
+ If the weights were specified as [1, 1, 0, 0] then the value would be 1.
+
+ This metric creates two variables, `total` and `count` that are used to
+ compute the average of `values`. This average is ultimately returned as sum
+ over batch size which is an idempotent operation that simply divides `total`
+ by `count`.
+
+ If `sample_weight` is `None`, weights default to 1. Use `sample_weight` of 0
+ to mask values.
+ """
+
+ def __init__(self, name='sum_over_batch_size', dtype=None):
+ super(SumOverBatchSize, self).__init__(
+ reduction=metrics_utils.Reduction.SUM_OVER_BATCH_SIZE,
+ name=name,
+ dtype=dtype)
+
+
+class SumOverBatchSizeMetricWrapper(SumOverBatchSize):
+ """Wraps a function with the `SumOverBatchSizeMetricWrapper` metric."""
+
+ def __init__(self, fn, name=None, dtype=None, **kwargs):
+ """Creates a `SumOverBatchSizeMetricWrapper` instance.
+
+ Args:
+ fn: The metric function to wrap, with signature `fn(y_true, y_pred,
+ **kwargs)`.
+ name: (Optional) string name of the metric instance.
+ dtype: (Optional) data type of the metric result.
+ **kwargs: The keyword arguments that are passed on to `fn`.
+ """
+ super(SumOverBatchSizeMetricWrapper, self).__init__(name=name, dtype=dtype)
+ self._fn = fn
+ self._fn_kwargs = kwargs
+
+ def update_state(self, y_true, y_pred, sample_weight=None):
+ y_true = math_ops.cast(y_true, self._dtype)
+ y_pred = math_ops.cast(y_pred, self._dtype)
+ y_pred, y_true, sample_weight = squeeze_or_expand_dimensions(
+ y_pred, y_true, sample_weight)
+
+ matches = self._fn(y_true, y_pred, **self._fn_kwargs)
+ return super(SumOverBatchSizeMetricWrapper, self).update_state(
+ matches, sample_weight=sample_weight)
+
+ def get_config(self):
+ config = {}
+ for k, v in six.iteritems(self._fn_kwargs):
+ config[k] = K.eval(v) if is_tensor_or_variable(v) else v
+ base_config = super(SumOverBatchSizeMetricWrapper, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
+
def accuracy(y_true, y_pred):
y_pred.get_shape().assert_is_compatible_with(y_true.get_shape())
if y_true.dtype != y_pred.dtype:
diff --git a/tensorflow/python/keras/metrics_correctness_test.py b/tensorflow/python/keras/metrics_correctness_test.py
new file mode 100644
index 0000000..b2385aa
--- /dev/null
+++ b/tensorflow/python/keras/metrics_correctness_test.py
@@ -0,0 +1,326 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests metrics correctness using Keras model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.keras import keras_parameterized
+from tensorflow.python.keras import layers
+from tensorflow.python.keras import metrics
+from tensorflow.python.keras import testing_utils
+from tensorflow.python.platform import test
+
+
+@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
+@keras_parameterized.run_all_keras_modes
+class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
+ # TODO(psv): Remove the run_eagerly checks here when b/123082095 is fixed.
+
+ def _get_multi_io_model(self):
+ inp_1 = layers.Input(shape=(1,), name='input_1')
+ inp_2 = layers.Input(shape=(1,), name='input_2')
+ x = layers.Dense(3, kernel_initializer='ones', trainable=False)
+ out_1 = layers.Dense(
+ 1, kernel_initializer='ones', name='output_1', trainable=False)
+ out_2 = layers.Dense(
+ 1, kernel_initializer='ones', name='output_2', trainable=False)
+
+ branch_a = [inp_1, x, out_1]
+ branch_b = [inp_2, x, out_2]
+ model = testing_utils.get_multi_io_model(branch_a, branch_b)
+ model.compile(
+ optimizer='rmsprop',
+ loss='mse',
+ metrics=[metrics.MeanSquaredError()],
+ weighted_metrics=[metrics.MeanSquaredError()],
+ run_eagerly=testing_utils.should_run_eagerly())
+ return model
+
+ def _custom_generator(self):
+ batch_size = 2
+ num_samples = 4
+ inputs = np.asarray([[1.], [2.], [3.], [4.]])
+ targets = np.asarray([[2.], [4.], [6.], [8.]])
+ w1 = np.asarray([2., 3., 4., 5.])
+ w2 = np.asarray([3.5, 2.5, 1.5, 0.5])
+ i = 0
+ while True:
+ batch_index = i * batch_size % num_samples
+ i += 1
+ start = batch_index
+ end = start + batch_size
+ x = [inputs[start:end], inputs[start:end]]
+ y = [targets[start:end], targets[start:end]]
+ w = [w1[start:end], w2[start:end]]
+ yield x, y, w
+
+ def setUp(self):
+ super(TestMetricsCorrectnessMultiIO, self).setUp()
+ self.x = np.asarray([[1.], [2.], [3.], [4.]])
+ self.y = np.asarray([[2.], [4.], [6.], [8.]])
+ self.weights_1 = np.asarray([2., 3., 4., 5.])
+ self.weights_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
+
+ # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
+
+ # Metric `output_1`, `output_2`:
+ # Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30,
+ # Count = 2 + 2
+ # Result = 7.5
+
+ # Weighted metric `output_1`:
+ # Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
+ # ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
+ # = 130
+ # Count = (2 + 3) + (4 + 5)
+ # Result = 9.2857141
+
+ # Weighted metric `output_2`:
+ # Total = ((3 - 2)^2 * 3.5 + (6 - 4)^2 * 2.5) +
+ # ((9 - 6)^2 * 1.5 + (12 - 8)^2 * 0.5)
+ # = 35
+ # Count = (3.5 + 2.5) + (1.5 + 0.5)
+ # Result = 4.375
+
+ # Loss `output_1`:
+ # Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
+ # ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
+ # = 130
+ # Count = 2 + 2
+ # Result = 32.5
+
+ # Loss `output_2`:
+ # Total = ((3 - 2)^2 * 3.5 + (6 - 4)^2 * 2.5) +
+ # ((9 - 6)^2 * 1.5 + (12 - 8)^2 * 0.5)
+ # = 35
+ # Count = 2 + 2
+ # Result = 8.75
+
+ # Total loss = 32.5 + 8.75 = 41.25
+
+ self.expected_fit_result = {
+ 'output_1_mean_squared_error': [7.5, 7.5],
+ 'output_2_mean_squared_error': [7.5, 7.5],
+ 'output_1_weighted_mean_squared_error': [9.286, 9.286],
+ 'output_2_weighted_mean_squared_error': [4.375, 4.375],
+ 'loss': [41.25, 41.25]
+ }
+
+ # In the order: 'loss', 'output_1_loss', 'output_2_loss',
+ # 'output_1_mean_squared_error', 'output_1_weighted_mean_squared_error',
+ # 'output_2_mean_squared_error', 'output_2_weighted_mean_squared_error'
+ self.expected_batch_result = [41.25, 32.5, 8.75, 7.5, 9.286, 7.5, 4.375]
+
+ def test_fit(self):
+ model = self._get_multi_io_model()
+ history = model.fit([self.x, self.x], [self.y, self.y],
+ sample_weight={
+ 'output_1': self.weights_1,
+ 'output_2': self.weights_2,
+ },
+ batch_size=2,
+ epochs=2,
+ shuffle=False)
+
+ if not model.run_eagerly:
+ self.expected_fit_result['output_1_loss'] = [32.5, 32.5]
+ self.expected_fit_result['output_2_loss'] = [8.75, 8.75]
+
+ for key, value in self.expected_fit_result.items():
+ self.assertAllClose(history.history[key], value, 1e-3)
+
+ def test_eval(self):
+ model = self._get_multi_io_model()
+ eval_result = model.evaluate([self.x, self.x], [self.y, self.y],
+ batch_size=2,
+ sample_weight={
+ 'output_1': self.weights_1,
+ 'output_2': self.weights_2,
+ })
+
+ if model.run_eagerly:
+ self.expected_batch_result = [41.25, 58, 10.75, 7.5, 9.286, 7.5, 4.375]
+ self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
+
+ if model.run_eagerly:
+ return
+ # Verify that metric value is same with arbitrary weights and batch size.
+ x = np.random.random((50, 1))
+ y = np.random.random((50, 1))
+ w = np.random.random((50,))
+ mse1 = model.evaluate([x, x], [y, y], sample_weight=[w, w], batch_size=5)[3]
+ mse2 = model.evaluate([x, x], [y, y], sample_weight=[w, w],
+ batch_size=10)[3]
+ self.assertAllClose(mse1, mse2, 1e-3)
+
+ def test_train_on_batch(self):
+ model = self._get_multi_io_model()
+ result = model.train_on_batch([self.x, self.x], [self.y, self.y],
+ sample_weight={
+ 'output_1': self.weights_1,
+ 'output_2': self.weights_2,
+ })
+ self.assertAllClose(result, self.expected_batch_result, 1e-3)
+
+ def test_test_on_batch(self):
+ model = self._get_multi_io_model()
+ result = model.test_on_batch([self.x, self.x], [self.y, self.y],
+ sample_weight={
+ 'output_1': self.weights_1,
+ 'output_2': self.weights_2,
+ })
+ self.assertAllClose(result, self.expected_batch_result, 1e-3)
+
+ def test_fit_generator(self):
+ model = self._get_multi_io_model()
+ history = model.fit_generator(
+ self._custom_generator(), steps_per_epoch=2, epochs=2)
+
+ if not model.run_eagerly:
+ self.expected_fit_result['output_1_loss'] = [32.5, 32.5]
+ self.expected_fit_result['output_2_loss'] = [8.75, 8.75]
+ for key, value in self.expected_fit_result.items():
+ self.assertAllClose(history.history[key], value, 1e-3)
+
+ def test_eval_generator(self):
+ model = self._get_multi_io_model()
+ eval_result = model.evaluate_generator(self._custom_generator(), steps=2)
+ if model.run_eagerly:
+ self.expected_batch_result = [41.25, 58, 10.75, 7.5, 9.286, 7.5, 4.375]
+ self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
+
+
+@keras_parameterized.run_with_all_model_types
+@keras_parameterized.run_all_keras_modes
+class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase):
+
+ def _get_model(self):
+ x = layers.Dense(3, kernel_initializer='ones', trainable=False)
+ out = layers.Dense(
+ 1, kernel_initializer='ones', name='output', trainable=False)
+ model = testing_utils.get_model_from_layers([x, out], input_shape=(1,))
+ model.compile(
+ optimizer='rmsprop',
+ loss='mse',
+ metrics=[metrics.MeanSquaredError()],
+ weighted_metrics=[metrics.MeanSquaredError()],
+ run_eagerly=testing_utils.should_run_eagerly())
+ return model
+
+ def _custom_generator(self):
+ batch_size = 2
+ num_samples = 4
+ x = np.asarray([[1.], [2.], [3.], [4.]])
+ y = np.asarray([[2.], [4.], [6.], [8.]])
+ w = np.asarray([2., 3., 4., 5.])
+ i = 0
+ while True:
+ batch_index = i * batch_size % num_samples
+ i += 1
+ start = batch_index
+ end = start + batch_size
+ yield x[start:end], y[start:end], w[start:end]
+
+ def setUp(self):
+ super(TestMetricsCorrectnessSingleIO, self).setUp()
+ self.x = np.asarray([[1.], [2.], [3.], [4.]])
+ self.y = np.asarray([[2.], [4.], [6.], [8.]])
+ self.weights = np.asarray([2., 3., 4., 5.])
+
+ # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]
+
+ # Metric:
+ # Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30,
+ # Count = 2 + 2
+ # Result = 7.5
+
+ # Weighted metric:
+ # Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
+ # ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
+ # = 130
+ # Count = (2 + 3) + (4 + 5)
+ # Result = 9.2857141
+
+ # Total loss:
+ # Total = ((3 - 2)^2 * 2 + (6 - 4)^2 * 3) +
+ # ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
+ # = 130,
+ # Count = 2 + 2
+ # Result = 32.5
+
+ self.expected_fit_result = {
+ 'mean_squared_error': [7.5, 7.5],
+ 'weighted_mean_squared_error': [9.286, 9.286],
+ 'loss': [32.5, 32.5]
+ }
+
+ # In the order: 'loss', 'mean_squared_error', 'weighted_mean_squared_error'
+ self.expected_batch_result = [32.5, 7.5, 9.286]
+
+ def test_fit(self):
+ model = self._get_model()
+ history = model.fit(
+ self.x,
+ self.y,
+ sample_weight=self.weights,
+ batch_size=2,
+ epochs=2,
+ shuffle=False)
+ for key, value in self.expected_fit_result.items():
+ self.assertAllClose(history.history[key], value, 1e-3)
+
+ def test_eval(self):
+ model = self._get_model()
+ eval_result = model.evaluate(
+ self.x, self.y, batch_size=2, sample_weight=self.weights)
+ self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
+
+ # Verify that metric value is same with arbitrary weights and batch size.
+ x = np.random.random((50, 1))
+ y = np.random.random((50, 1))
+ w = np.random.random((50,))
+ mse1 = model.evaluate(x, y, sample_weight=w, batch_size=5)[1]
+ mse2 = model.evaluate(x, y, sample_weight=w, batch_size=10)[1]
+ self.assertAllClose(mse1, mse2, 1e-3)
+
+ def test_train_on_batch(self):
+ model = self._get_model()
+ result = model.train_on_batch(self.x, self.y, sample_weight=self.weights)
+ self.assertAllClose(result, self.expected_batch_result, 1e-3)
+
+ def test_test_on_batch(self):
+ model = self._get_model()
+ result = model.test_on_batch(self.x, self.y, sample_weight=self.weights)
+ self.assertAllClose(result, self.expected_batch_result, 1e-3)
+
+ def test_fit_generator(self):
+ model = self._get_model()
+ history = model.fit_generator(
+ self._custom_generator(), steps_per_epoch=2, epochs=2)
+ for key, value in self.expected_fit_result.items():
+ self.assertAllClose(history.history[key], value, 1e-3)
+
+ def test_eval_generator(self):
+ model = self._get_model()
+ eval_result = model.evaluate_generator(self._custom_generator(), steps=2)
+ self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 7d1f888..4143066 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -685,26 +685,43 @@
def test_unweighted(self):
hinge_obj = metrics.Hinge()
self.evaluate(variables.variables_initializer(hinge_obj.variables))
- y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
- (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
- y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
- (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # metric = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # metric = [(0.7 + 0.8 + 0.9 + 0) / 4, (0.75 + 0 + 0.5 + 0.4) / 4]
+ # = [0.6, 0.4125]
+ # reduced metric = (0.6 + 0.4125) / 2
update_op = hinge_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = hinge_obj.result()
- self.assertAllClose(0.65, result, atol=1e-5)
+ self.assertAllClose(0.506, result, atol=1e-3)
def test_weighted(self):
hinge_obj = metrics.Hinge()
self.evaluate(variables.variables_initializer(hinge_obj.variables))
- y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
- (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
- y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
- (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
- sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
+ y_true = constant_op.constant([[-1, 1, -1, 1], [-1, -1, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+ sample_weight = constant_op.constant([1.5, 2.])
+
+ # metric = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # metric = [(0.7 + 0.8 + 0.9 + 0) / 4, (0.75 + 0 + 0.5 + 0.4) / 4]
+ # = [0.6, 0.4125]
+ # weighted metric = [0.6 * 1.5, 0.4125 * 2]
+ # reduced metric = (0.6 * 1.5 + 0.4125 * 2) / (1.5 + 2)
+
result = hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose(0.65714, self.evaluate(result), atol=1e-5)
+ self.assertAllClose(0.493, self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes
@@ -723,26 +740,49 @@
def test_unweighted(self):
sq_hinge_obj = metrics.SquaredHinge()
self.evaluate(variables.variables_initializer(sq_hinge_obj.variables))
- y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
- (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
- y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
- (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
+ y_true = constant_op.constant([[0, 1, 0, 1], [0, 0, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+
+ # metric = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true = [[-1, 1, -1, 1], [-1, -1, 1, 1]]
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # max(0, 1 - y_true * y_pred) = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # squared(max(0, 1 - y_true * y_pred)) = [[0.49, 0.64, 0.81, 0],
+ # [0.5625, 0, 0.25, 0.16]]
+ # metric = [(0.49 + 0.64 + 0.81 + 0) / 4, (0.5625 + 0 + 0.25 + 0.16) / 4]
+ # = [0.485, 0.2431]
+ # reduced metric = (0.485 + 0.2431) / 2
update_op = sq_hinge_obj.update_state(y_true, y_pred)
self.evaluate(update_op)
result = sq_hinge_obj.result()
- self.assertAllClose(0.65, result, atol=1e-5)
+ self.assertAllClose(0.364, result, atol=1e-3)
def test_weighted(self):
sq_hinge_obj = metrics.SquaredHinge()
self.evaluate(variables.variables_initializer(sq_hinge_obj.variables))
- y_true = constant_op.constant(((0, 1, 0, 1, 0), (0, 0, 1, 1, 1),
- (1, 1, 1, 1, 0), (0, 0, 0, 0, 1)))
- y_pred = constant_op.constant(((0, 0, 1, 1, 0), (1, 1, 1, 1, 1),
- (0, 1, 0, 1, 0), (1, 1, 1, 1, 1)))
- sample_weight = constant_op.constant((1., 1.5, 2., 2.5))
+ y_true = constant_op.constant([[-1, 1, -1, 1], [-1, -1, 1, 1]])
+ y_pred = constant_op.constant([[-0.3, 0.2, -0.1, 1.6],
+ [-0.25, -1., 0.5, 0.6]])
+ sample_weight = constant_op.constant([1.5, 2.])
+
+ # metric = max(0, 1-y_true * y_pred), where y_true is -1/1
+
+ # y_true * y_pred = [[0.3, 0.2, 0.1, 1.6], [0.25, 1, 0.5, 0.6]]
+ # 1 - y_true * y_pred = [[0.7, 0.8, 0.9, -0.6], [0.75, 0, 0.5, 0.4]]
+ # max(0, 1 - y_true * y_pred) = [[0.7, 0.8, 0.9, 0], [0.75, 0, 0.5, 0.4]]
+ # squared(max(0, 1 - y_true * y_pred)) = [[0.49, 0.64, 0.81, 0],
+ # [0.5625, 0, 0.25, 0.16]]
+ # metric = [(0.49 + 0.64 + 0.81 + 0) / 4, (0.5625 + 0 + 0.25 + 0.16) / 4]
+ # = [0.485, 0.2431]
+ # weighted metric = [0.485 * 1.5, 0.2431 * 2]
+ # reduced metric = (0.485 * 1.5 + 0.2431 * 2) / (1.5 + 2)
+
result = sq_hinge_obj(y_true, y_pred, sample_weight=sample_weight)
- self.assertAllClose(0.65714, self.evaluate(result), atol=1e-5)
+ self.assertAllClose(0.347, self.evaluate(result), atol=1e-3)
@test_util.run_all_in_graph_and_eager_modes
diff --git a/tensorflow/python/keras/ops.py b/tensorflow/python/keras/ops.py
index bc14eef..44e0228 100644
--- a/tensorflow/python/keras/ops.py
+++ b/tensorflow/python/keras/ops.py
@@ -56,22 +56,33 @@
keras_export("keras.initializers.Initializer", v1=[])(
init_ops_v2.Initializer)
-keras_export("keras.initializers.Zeros", v1=[])(
- init_ops_v2.Zeros)
-keras_export("keras.initializers.Ones", v1=[])(
- init_ops_v2.Ones)
-keras_export("keras.initializers.Constant", v1=[])(
- init_ops_v2.Constant)
+keras_export(
+ "keras.initializers.Zeros", "keras.initializers.zeros", v1=[])(
+ init_ops_v2.Zeros)
+keras_export(
+ "keras.initializers.Ones", "keras.initializers.ones", v1=[])(
+ init_ops_v2.Ones)
+keras_export(
+ "keras.initializers.Constant", "keras.initializers.constant", v1=[])(
+ init_ops_v2.Constant)
keras_export("keras.initializers.VarianceScaling", v1=[])(
init_ops_v2.VarianceScaling)
-keras_export("keras.initializers.Orthogonal", v1=[])(
- init_ops_v2.Orthogonal)
-keras_export("keras.initializers.Identity", v1=[])(
- init_ops_v2.Identity)
-keras_export("keras.initializers.GlorotUniform", v1=[])(
- init_ops_v2.GlorotUniform)
-keras_export("keras.initializers.GlorotNormal", v1=[])(
- init_ops_v2.GlorotNormal)
+keras_export(
+ "keras.initializers.Orthogonal", "keras.initializers.orthogonal", v1=[])(
+ init_ops_v2.Orthogonal)
+keras_export(
+ "keras.initializers.Identity", "keras.initializers.identity", v1=[])(
+ init_ops_v2.Identity)
+keras_export(
+ "keras.initializers.GlorotUniform",
+ "keras.initializers.glorot_uniform",
+ v1=[])(
+ init_ops_v2.GlorotUniform)
+keras_export(
+ "keras.initializers.GlorotNormal",
+ "keras.initializers.glorot_normal",
+ v1=[])(
+ init_ops_v2.GlorotNormal)
keras_export("keras.initializers.lecun_normal", v1=[])(
init_ops_v2.lecun_normal)
keras_export("keras.initializers.lecun_uniform", v1=[])(
diff --git a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
index a182d74..c44263b 100644
--- a/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
+++ b/tensorflow/python/keras/optimizer_v2/learning_rate_schedule.py
@@ -631,8 +631,7 @@
}
-@keras_export("keras.experimental.CosineDecayRestarts",
- v1=[])
+@keras_export("keras.experimental.CosineDecayRestarts")
class CosineDecayRestarts(LearningRateSchedule):
"""A LearningRateSchedule that uses a cosine decay schedule with restarts."""
@@ -761,8 +760,7 @@
}
-@keras_export("keras.experimental.LinearCosineDecay",
- v1=[])
+@keras_export("keras.experimental.LinearCosineDecay")
class LinearCosineDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a linear cosine decay schedule."""
@@ -879,8 +877,7 @@
}
-@keras_export("keras.experimental.NoisyLinearCosineDecay",
- v1=[])
+@keras_export("keras.experimental.NoisyLinearCosineDecay")
class NoisyLinearCosineDecay(LearningRateSchedule):
"""A LearningRateSchedule that uses a noisy linear cosine decay schedule."""
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
index bf6dcaa..b701c98 100644
--- a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -25,7 +25,6 @@
import six
-from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.distribute import values as distributed_values
@@ -126,7 +125,26 @@
opt.apply_gradients(capped_grads_and_vars)
```
+ ### Use with `tf.distribute.Strategy`.
+
+ This optimizer class is `tf.distribute.Strategy` aware, which means it
+ automatically sums gradients across all replicas. To average gradients,
+ you divide your loss by the global batch size, which is done automatically
+ if you use a member of `tf.keras.losses` or `tf.losses`. See the
+ `reduction` argument of your loss which should be set to
+ `tf.losses.Reduction.SUM_OVER_BATCH_SIZE` for averaging or
+ `tf.losses.Reduction.SUM` for not.
+
+ If you are not using these and you want to average gradients, you should use
+ `tf.math.reduce_sum` to add up your per-example losses and then divide by the
+ global batch size. Note that when using `tf.distribute.Strategy`, the first
+ component of a tensor's shape is the *replica-local* batch size, which is off
+ by a factor equal to the number of replicas being used to compute a single
+ step. As a result, using `tf.math.reduce_mean` will give the wrong answer,
+ resulting in gradients that can be many times too big.
+
### Variable Constraint
+
All Keras optimizers respect variable constraints. If constraint function is
passed to any variable, the constraint will be applied to the variable after
the gradient has been applied to the variable.
@@ -174,15 +192,13 @@
```
### Write a customized optimizer.
+ If you intend to create your own optimization algorithm, simply inherit from
+ this class and override the following methods:
- This optimizer class updates variables from gradients and is
- tf.distribute.Strategy aware. If you intend to create your own optimization
- algorithm, simply inherit from this class and override the following methods:
- resource_apply_dense (update variable given gradient tensor is dense)
- resource_apply_sparse (update variable given gradient tensor is sparse)
- create_slots (if your optimizer algorithm requires additional variables)
- get_config (serialization of the optimizer, include all hyper parameters)
-
"""
def __init__(self, name, **kwargs):
@@ -310,7 +326,6 @@
with backprop.GradientTape() as tape:
tape.watch(var_list)
loss_value = loss()
- loss_value = self._scale_loss(loss_value)
grads = tape.gradient(loss_value, var_list, grad_loss)
if hasattr(self, "clipnorm"):
@@ -329,14 +344,6 @@
return grads_and_vars
- @staticmethod
- def _scale_loss(loss_value):
- if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
- num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
- if num_replicas > 1:
- loss_value *= (1. / num_replicas)
- return loss_value
-
def get_gradients(self, loss, params):
"""Returns gradients of `loss` with respect to `params`.
@@ -351,7 +358,6 @@
ValueError: In case any gradient cannot be computed (e.g. if gradient
function not implemented).
"""
- loss = self._scale_loss(loss)
grads = gradients.gradients(loss, params)
if None in grads:
raise ValueError("An operation has `None` for gradient. "
diff --git a/tensorflow/python/keras/regularizers_test.py b/tensorflow/python/keras/regularizers_test.py
index 3d6b259..3aca0c7 100644
--- a/tensorflow/python/keras/regularizers_test.py
+++ b/tensorflow/python/keras/regularizers_test.py
@@ -18,9 +18,11 @@
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+
from tensorflow.python import keras
-from tensorflow.python.keras import testing_utils
from tensorflow.python.framework import test_util
+from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
@@ -28,50 +30,53 @@
NUM_CLASSES = 2
-def get_data():
- (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
- train_samples=10,
- test_samples=10,
- input_shape=(DATA_DIM,),
- num_classes=NUM_CLASSES)
- y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
- y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)
- return (x_train, y_train), (x_test, y_test)
+class KerasRegularizersTest(test.TestCase, parameterized.TestCase):
+ def create_model(self, kernel_regularizer=None, activity_regularizer=None):
+ model = keras.models.Sequential()
+ model.add(keras.layers.Dense(NUM_CLASSES,
+ kernel_regularizer=kernel_regularizer,
+ activity_regularizer=activity_regularizer,
+ input_shape=(DATA_DIM,)))
+ return model
-def create_model(kernel_regularizer=None, activity_regularizer=None):
- model = keras.models.Sequential()
- model.add(keras.layers.Dense(NUM_CLASSES,
- kernel_regularizer=kernel_regularizer,
- activity_regularizer=activity_regularizer,
- input_shape=(DATA_DIM,)))
- return model
+ def get_data(self):
+ (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
+ train_samples=10,
+ test_samples=10,
+ input_shape=(DATA_DIM,),
+ num_classes=NUM_CLASSES)
+ y_train = keras.utils.to_categorical(y_train, NUM_CLASSES)
+ y_test = keras.utils.to_categorical(y_test, NUM_CLASSES)
+ return (x_train, y_train), (x_test, y_test)
-
-class KerasRegularizersTest(test.TestCase):
-
- def test_kernel_regularization(self):
+ @parameterized.named_parameters([
+ ('l1', keras.regularizers.l1()),
+ ('l2', keras.regularizers.l2()),
+ ('l1_l2', keras.regularizers.l1_l2()),
+ ])
+ def test_kernel_regularization(self, regularizer):
with self.cached_session():
- (x_train, y_train), _ = get_data()
- for reg in [keras.regularizers.l1(),
- keras.regularizers.l2(),
- keras.regularizers.l1_l2()]:
- model = create_model(kernel_regularizer=reg)
- model.compile(loss='categorical_crossentropy', optimizer='sgd')
- assert len(model.losses) == 1
- model.fit(x_train, y_train, batch_size=10,
- epochs=1, verbose=0)
+ (x_train, y_train), _ = self.get_data()
+ model = self.create_model(kernel_regularizer=regularizer)
+ model.compile(loss='categorical_crossentropy', optimizer='sgd')
+ assert len(model.losses) == 1
+ model.fit(x_train, y_train, batch_size=10,
+ epochs=1, verbose=0)
- @test_util.run_deprecated_v1
- def test_activity_regularization(self):
+ @parameterized.named_parameters([
+ ('l1', keras.regularizers.l1()),
+ ('l2', keras.regularizers.l2()),
+ ])
+ @test_util.deprecated_graph_mode_only
+ def test_activity_regularization(self, regularizer):
with self.cached_session():
- (x_train, y_train), _ = get_data()
- for reg in [keras.regularizers.l1(), keras.regularizers.l2()]:
- model = create_model(activity_regularizer=reg)
- model.compile(loss='categorical_crossentropy', optimizer='sgd')
- assert len(model.losses) == 1
- model.fit(x_train, y_train, batch_size=10,
- epochs=1, verbose=0)
+ (x_train, y_train), _ = self.get_data()
+ model = self.create_model(activity_regularizer=regularizer)
+ model.compile(loss='categorical_crossentropy', optimizer='sgd')
+ assert len(model.losses) == 1
+ model.fit(x_train, y_train, batch_size=10,
+ epochs=1, verbose=0)
def test_zero_regularization(self):
inputs = keras.backend.ones(shape=(10, 10))
diff --git a/tensorflow/python/keras/saving/saved_model.py b/tensorflow/python/keras/saving/saved_model.py
index fbf0bf6..a614359 100644
--- a/tensorflow/python/keras/saving/saved_model.py
+++ b/tensorflow/python/keras/saving/saved_model.py
@@ -37,7 +37,7 @@
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import mode_keys
from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training.checkpointable import util as checkpointable_utils
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
@@ -52,7 +52,7 @@
`save_model` generates new files/folders under the `saved_model_path` folder:
1) a checkpoint containing the model weights.
2) a saved_model.pb file containing the model's MetaGraphs. The prediction
- graph is always exported. The evaluaton and training graphs are exported
+ graph is always exported. The evaluation and training graphs are exported
if the following conditions are met:
- Evaluation: model loss is defined.
- Training: model is compiled with an optimizer defined under `tf.train`.
@@ -220,7 +220,8 @@
def _get_var_list(model):
"""Returns list of all checkpointed saveable objects in the model."""
- return checkpointable_utils.named_saveables(model)
+ var_list, _, _ = graph_view.ObjectGraphView(model).serialize_object_graph()
+ return var_list
def create_placeholder(spec):
@@ -287,7 +288,7 @@
clone._make_predict_function()
g.get_collection_ref(ops.GraphKeys.UPDATE_OPS).extend(clone.state_updates)
- clone_var_list = checkpointable_utils.named_saveables(clone)
+ clone_var_list = _get_var_list(clone)
with session.Session().as_default():
if has_saved_vars:
diff --git a/tensorflow/python/keras/saving/saved_model_test.py b/tensorflow/python/keras/saving/saved_model_test.py
index 8063b8a..6ecb4ed 100644
--- a/tensorflow/python/keras/saving/saved_model_test.py
+++ b/tensorflow/python/keras/saving/saved_model_test.py
@@ -38,7 +38,6 @@
from tensorflow.python.platform import test
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import model_utils
-from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import mode_keys
from tensorflow.python.training import training as training_module
@@ -264,10 +263,7 @@
def load_model(sess, path, mode):
tags = model_utils.EXPORT_TAG_MAP[mode]
- if mode == mode_keys.ModeKeys.PREDICT:
- sig_def_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
- else:
- sig_def_key = mode
+ sig_def_key = model_utils.SIGNATURE_KEY_MAP[mode]
meta_graph_def = loader_impl.load(sess, tags, path)
inputs = {
diff --git a/tensorflow/python/keras/utils/data_utils.py b/tensorflow/python/keras/utils/data_utils.py
index 9b4a50d..0f6e89b 100644
--- a/tensorflow/python/keras/utils/data_utils.py
+++ b/tensorflow/python/keras/utils/data_utils.py
@@ -246,10 +246,10 @@
try:
try:
urlretrieve(origin, fpath, dl_progress)
- except URLError as e:
- raise Exception(error_msg.format(origin, e.errno, e.reason))
except HTTPError as e:
raise Exception(error_msg.format(origin, e.code, e.msg))
+ except URLError as e:
+ raise Exception(error_msg.format(origin, e.errno, e.reason))
except (Exception, KeyboardInterrupt) as e:
if os.path.exists(fpath):
os.remove(fpath)
diff --git a/tensorflow/python/keras/utils/losses_utils.py b/tensorflow/python/keras/utils/losses_utils.py
index d42b354..899780e 100644
--- a/tensorflow/python/keras/utils/losses_utils.py
+++ b/tensorflow/python/keras/utils/losses_utils.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend as K
@@ -148,7 +149,9 @@
else:
loss = math_ops.reduce_sum(weighted_losses)
if reduction == losses_impl.ReductionV2.SUM_OVER_BATCH_SIZE:
- loss = _safe_mean(loss, _num_elements(weighted_losses))
+ num_replicas = ( # Used to convert from local to global batch size.
+ distribution_strategy_context.get_strategy().num_replicas_in_sync)
+ loss = _safe_mean(loss, num_replicas * _num_elements(weighted_losses))
return loss
@@ -177,11 +180,6 @@
if sample_weight is None:
sample_weight = 1.0
with ops.name_scope(name, 'weighted_loss', (losses, sample_weight)):
- # Save the `reduction` argument for loss normalization when distributing
- # to multiple replicas.
- # TODO(josh11b): Associate it with the returned op for more precision.
- ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
-
# Update dimensions of `sample_weight` to match with `losses` if possible.
losses, _, sample_weight = squeeze_or_expand_dimensions(
losses, None, sample_weight)
diff --git a/tensorflow/python/kernel_tests/benchmark_test.py b/tensorflow/python/kernel_tests/benchmark_test.py
index a91f96c..3fa2054 100644
--- a/tensorflow/python/kernel_tests/benchmark_test.py
+++ b/tensorflow/python/kernel_tests/benchmark_test.py
@@ -126,7 +126,7 @@
self.assertFalse(_ran_somebenchmark_2[0])
self.assertFalse(_ran_somebenchmark_but_shouldnt[0])
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123744455") # GPU memory is incorrect
def testReportingBenchmark(self):
tempdir = test.get_temp_dir()
try:
diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py
index d5f3696..7d00919 100644
--- a/tensorflow/python/kernel_tests/check_ops_test.py
+++ b/tensorflow/python/kernel_tests/check_ops_test.py
@@ -889,8 +889,8 @@
# Dynamic shape check
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA"
- ) # Dynamic shapes not supported now with XLA
+ @test_util.disable_xla(
+ "b/123337890") # Dynamic shapes not supported now with XLA
def testEnsuresDynamicShape_RaisesError(self):
placeholder = array_ops.placeholder(dtypes.int32)
derived = math_ops.divide(placeholder, 3, name="MyDivide")
@@ -904,8 +904,8 @@
sess.run(derived, feed_dict={placeholder: feed_val})
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA"
- ) # Dynamic shapes not supported now with XLA
+ @test_util.disable_xla(
+ "b/123337890") # Dynamic shapes not supported now with XLA
def testEnsuresDynamicShape_RaisesErrorDimUnknown(self):
placeholder = array_ops.placeholder(dtypes.int32)
derived = placeholder / 3
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index abb71a6..2305c0b 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -163,7 +163,9 @@
with self.assertRaises(ValueError):
linalg_ops.cholesky(tensor3)
- @test_util.disable_xla("This test never passed for XLA") # all nan on XLA
+ # The below invalid Cholesky call returns an error with TF Classic and just
+ # returns NaNs with XLA.
+ @test_util.disable_xla("b/123337890")
def testNotInvertibleCPU(self):
# The input should be invertible.
with self.session(use_gpu=True):
diff --git a/tensorflow/python/kernel_tests/concat_op_test.py b/tensorflow/python/kernel_tests/concat_op_test.py
index a968b06..7e37785 100644
--- a/tensorflow/python/kernel_tests/concat_op_test.py
+++ b/tensorflow/python/kernel_tests/concat_op_test.py
@@ -33,7 +33,6 @@
from tensorflow.python.platform import test
-@test_util.disable_all_xla("This test never passed for XLA")
class ConcatOpTest(test.TestCase):
@test_util.run_deprecated_v1
@@ -642,7 +641,6 @@
self.assertAllEqual([[1, 2, 3, 7, 8, 9], [4, 5, 6, 10, 11, 12]], output)
-@test_util.disable_all_xla("This test never passed for XLA")
class ConcatOffsetTest(test.TestCase):
def testBasic(self):
@@ -686,8 +684,7 @@
self.evaluate(off)
@test_util.run_deprecated_v1
- @test_util.disable_xla(
- "This test never passed for XLA") # Different error message on XLA
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testSizeMismatch(self):
cdim = constant_op.constant(1, dtypes.int32)
s0 = constant_op.constant([2, 3, 5], dtypes.int32)
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index c5b7a95..a073bb3 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1055,7 +1055,8 @@
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(cond()), 10)
- self.assertEqual(printed.contents(), "A\nB\nC\n")
+ self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
+ printed.contents())
@eager_function.defun
def nested_cond():
@@ -1063,7 +1064,8 @@
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(nested_cond()), 10)
- self.assertEqual(printed.contents(), "A\nB\nC\n")
+ self.assertTrue(printed.contents().endswith("A\nB\nC\n"),
+ printed.contents())
# wrap_function should prune.
def pruned_cond():
@@ -1112,11 +1114,13 @@
with self.cached_session():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_while()[0]), 2)
- self.assertEqual(printed.contents(), "D\nD\n")
+ self.assertTrue(printed.contents().endswith("D\nD\n"),
+ printed.contents())
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(build_nested_while()[0]), 2)
- self.assertEqual(printed.contents(), "D\nD\n")
+ self.assertTrue(printed.contents().endswith("D\nD\n"),
+ printed.contents())
# In defuns, all prints should execute in program order.
@eager_function.defun
@@ -1125,7 +1129,8 @@
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(while_loop()), 2)
- self.assertEqual(printed.contents(), "A\nB\nC\nD\nA\nB\nC\nD\nA\n")
+ self.assertTrue(printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
+ printed.contents())
@eager_function.defun
def nested_while_loop():
@@ -1135,7 +1140,9 @@
if not context.executing_eagerly():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(nested_while_loop()), 2)
- self.assertEqual(printed.contents(), "A\nB\nC\nD\nA\nB\nC\nD\nA\n")
+ self.assertTrue(
+ printed.contents().endswith("A\nB\nC\nD\nA\nB\nC\nD\nA\n"),
+ printed.contents())
# wrap_function should prune.
def pruned_while():
@@ -1144,7 +1151,7 @@
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_while()), 2)
- self.assertEqual(printed.contents(), "D\nD\n")
+ self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
def pruned_nested_while():
return build_nested_while()[0]
@@ -1154,7 +1161,7 @@
if not context.executing_eagerly():
with self.captureWritesToStream(sys.stderr) as printed:
self.assertEqual(self.evaluate(pruned_nested_while()), 2)
- self.assertEqual(printed.contents(), "D\nD\n")
+ self.assertTrue(printed.contents().endswith("D\nD\n"), printed.contents())
# Microbenchmark: 256,000 iterations/s.
def testWhile_1(self):
@@ -1358,7 +1365,7 @@
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
- @test_util.disable_control_flow_v2("b/118457764")
+ @test_util.disable_control_flow_v2("b/123601232")
@test_util.run_v1_only("b/120545219")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -2518,6 +2525,178 @@
self.evaluate(variables.global_variables_initializer())
self.assertAllClose(216.0, g[0])
+ def testWhileGrad_ResourceVarInFunctionCall(self):
+
+ @def_function.function
+ def foo(x, var):
+ return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
+
+ @def_function.function
+ def bar(var):
+ r = control_flow_ops.while_loop(
+ lambda i, _: i < 2,
+ lambda i, x: (i + 1, foo(x, var)),
+ [0, 0.0])[1]
+ return gradients_impl.gradients(r, var)[0]
+
+ var = resource_variable_ops.ResourceVariable([1., 2., 3., 4.])
+ self.evaluate(variables.global_variables_initializer())
+ grad = self.evaluate(bar(var))
+ self.assertIsInstance(grad, ops.IndexedSlicesValue)
+ self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
+
+ def testWhileGrad_ResourceVarInNestedFunctionCall(self):
+
+ @def_function.function
+ def foo(x, var):
+ return x + math_ops.reduce_sum(var.sparse_read([1, 3]))
+
+ @def_function.function
+ def foo2(x, var):
+ return foo(x, var)
+
+ @def_function.function
+ def bar(var):
+ r = control_flow_ops.while_loop(
+ lambda i, _: i < 2,
+ lambda i, x: (i + 1, foo2(x, var)),
+ [0, 0.0])[1]
+ return gradients_impl.gradients(r, var)[0]
+
+ var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
+ self.evaluate(variables.global_variables_initializer())
+ grad = self.evaluate(bar(var))
+ self.assertIsInstance(grad, ops.IndexedSlicesValue)
+ self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 2., 0., 2.])
+
+ def testWhileGrad_ResourceVarInLoopInFunctionCall(self):
+
+ @def_function.function
+ def foo(x, var):
+ return control_flow_ops.while_loop(
+ lambda j, _: j < 3,
+ lambda j, y: (j + 1,
+ y + math_ops.reduce_sum(var.sparse_read([1, 2]))),
+ [0, x])[1]
+
+ @def_function.function
+ def bar(var):
+ r = control_flow_ops.while_loop(
+ lambda i, _: i < 2,
+ lambda i, x: (i + 1, foo(x, var)),
+ [0, 0.0])[1]
+ return gradients_impl.gradients(r, var)[0]
+
+ var = resource_variable_ops.ResourceVariable([1., 1., 1., 1.])
+ self.evaluate(variables.global_variables_initializer())
+ grad = self.evaluate(bar(var))
+ self.assertIsInstance(grad, ops.IndexedSlicesValue)
+ self.assertAllEqual(gradient_checker_v2._to_numpy(grad), [0., 6., 6., 0.])
+
+ def testWhileCondGrad_ResourceVarInFunctionCall(self):
+
+ @def_function.function
+ def foo(x, var):
+ return x + var.sparse_read([1])[0]
+
+ def body(i, x):
+ return (i + 1, control_flow_ops.cond(
+ math_ops.equal(i % 2, 0),
+ lambda: foo(x, var1),
+ lambda: foo(x, var2)))
+
+ @def_function.function
+ def bar(var1, var2):
+ r = control_flow_ops.while_loop(
+ lambda i, _: i < 4, body, [0, 0.0])
+ return gradients_impl.gradients(r, [var1, var2])
+
+ var1 = resource_variable_ops.ResourceVariable([1., 2., 3.])
+ var2 = resource_variable_ops.ResourceVariable([4., 5.])
+ self.evaluate(variables.global_variables_initializer())
+ grads = self.evaluate(bar(var1, var2))
+ self.assertAllEqual(gradient_checker_v2._to_numpy(grads[0]), [0., 2., 0.])
+ self.assertAllEqual(gradient_checker_v2._to_numpy(grads[1]), [0., 2.])
+
+ @test_util.run_deprecated_v1
+ def testWhileGrad_ResourceVarSparseRead(self):
+ # NOTE(skyewm): this test is interesting because the
+ # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
+ var = resource_variable_ops.ResourceVariable(np.ones(5),
+ dtype=dtypes.float32)
+ r = control_flow_ops.while_loop(
+ lambda i, _: i < 3,
+ lambda i, x: (i + 1, x * math_ops.reduce_sum(var.sparse_read([1, 3]))),
+ [0, constant_op.constant(1.0)])[1]
+ grad = gradients_impl.gradients(r, var)[0]
+
+ self.evaluate(variables.global_variables_initializer())
+ grad_val = self.evaluate(grad)
+ self.assertIsInstance(grad_val, ops.IndexedSlicesValue)
+ arr = gradient_checker_v2._to_numpy(grad_val)
+ self.assertAllEqual(arr, [0., 12., 0., 12., 0.])
+
+ @test_util.run_deprecated_v1
+ def testWhileGrad_MultiResourceVarSparseRead(self):
+ # NOTE(skyewm): this test is interesting because the
+ # ResourceVariable.sparse_read gradient function returns an IndexedSlices.
+ var1 = resource_variable_ops.ResourceVariable(np.ones(5),
+ dtype=dtypes.float32)
+ var2 = resource_variable_ops.ResourceVariable(np.ones(3),
+ dtype=dtypes.float32)
+ x1_init = constant_op.constant([0., 0.])
+ x2_init = constant_op.constant(1.)
+ x3_init = constant_op.constant(1.)
+
+ def body(i, unused_x1, x2, x3):
+ y1 = var1.sparse_read([1, 3])
+ y2 = x2 * 2
+ y3 = x3 * math_ops.reduce_sum(var2.sparse_read([0]))
+ return i + 1, y1, y2, y3
+
+ r = control_flow_ops.while_loop(
+ lambda i, x1, x2, x3: i < 3, body,
+ [0, x1_init, x2_init, x3_init])[1:]
+ var1_grad, var2_grad = gradients_impl.gradients(r, [var1, var2])
+
+ self.evaluate(variables.global_variables_initializer())
+ var1_grad_val = self.evaluate(var1_grad)
+ var2_grad_val = self.evaluate(var2_grad)
+ self.assertIsInstance(var1_grad_val, ops.IndexedSlicesValue)
+ self.assertIsInstance(var2_grad_val, ops.IndexedSlicesValue)
+ self.assertAllEqual(gradient_checker_v2._to_numpy(var1_grad_val),
+ [0., 1., 0., 1., 0.])
+ self.assertAllEqual(gradient_checker_v2._to_numpy(var2_grad_val),
+ [3., 0., 0.])
+
+ @test_util.run_deprecated_v1
+ def testWhileGrad_Gather(self):
+ # NOTE(skyewm): this test is interesting because the gather gradient
+ # function returns an IndexedSlices.
+ x = constant_op.constant([1., 1., 1., 1., 1.])
+ y = control_flow_ops.while_loop(
+ lambda i, _: i < 3,
+ lambda i, x: (i + 1, x + array_ops.gather(x, [0])),
+ [0, x[:1]])[1]
+ z = y * 3.0
+ grad = gradients_impl.gradients(z, x)[0]
+ self.assertEqual(self.evaluate(y), 8.)
+ self.assertAllEqual(self.evaluate(grad), [24., 0., 0., 0., 0.])
+
+ @test_util.run_deprecated_v1
+ def testWhileGrad_GatherNoFanOut(self):
+ # NOTE(skyewm): this test is interesting because the gather gradient
+ # function returns an IndexedSlices.
+ x = constant_op.constant([1., 1., 1., 1., 1.])
+ y = control_flow_ops.while_loop(
+ lambda i, _: i < 3,
+ lambda i, x: (i + 1, array_ops.gather(x, [0])),
+ [0, x[:1]])[1]
+ z = y * 3.0
+ grad = gradients_impl.gradients(z, x)[0]
+ self.assertEqual(self.evaluate(y), 1.)
+ self.assertAllEqual(self.evaluate(grad), [3., 0., 0., 0., 0.])
+
@test_util.run_v1_only("b/120545219")
def testWhileGradInCond(self):
@@ -2590,8 +2769,6 @@
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
- @test_util.disable_xla("This test never passed for XLA"
- ) # Resource variable issue for ControlFlowV2
@test_util.run_gpu_only
def testGpuResourceAccess(self):
with ops.device(test.gpu_device_name()):
@@ -3134,25 +3311,24 @@
def testNestedWhileAndTensorArray(self):
n = constant_op.constant(3.0)
- def Body(row, ta, n):
+ def Body(row, ta):
- def InnerBody(row, col, ta, n):
+ def InnerBody(row, col, ta):
# Note: row and col are 1-based.
ta = ta.write(
math_ops.cast(n * (row - 1.) + col - 1., dtypes.int32), row * col)
- return row, col + 1., ta, n
+ return row, col + 1., ta
- # TODO(b/118457764): Remove n from loop_vars from both loops once fixed.
ta = control_flow_ops.while_loop(
- lambda _, col, _1, n: col <= n,
- InnerBody, [row, constant_op.constant(1.), ta, n],
+ lambda _, col, _1: col <= n,
+ InnerBody, [row, constant_op.constant(1.), ta],
return_same_structure=False)[2]
- return row + 1., ta, n
+ return row + 1., ta
ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=9)
ta = control_flow_ops.while_loop(
- lambda row, _, _1: row <= n,
- Body, [constant_op.constant(1.), ta, n],
+ lambda row, _: row <= n,
+ Body, [constant_op.constant(1.), ta],
return_same_structure=False)[1]
output = array_ops.reshape(ta.stack(), [3, 3])
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index d9b908b..732d870 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -574,7 +574,6 @@
padding="VALID")
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D0x0Padding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 2, 3, 3],
@@ -589,7 +588,6 @@
padding=[[0, 0], [0, 0]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D1x1Padding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 2, 3, 2],
@@ -604,7 +602,6 @@
padding=[[1, 1], [1, 1]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Padding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 2, 1, 2],
@@ -619,7 +616,6 @@
padding=[[2, 2], [2, 2]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2DOnlyBottomPadding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 2, 3, 3],
@@ -634,7 +630,6 @@
padding=[[0, 3], [0, 0]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2DOnlyTopRightPadding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 2, 3, 3],
@@ -650,7 +645,6 @@
padding=[[1, 0], [0, 2]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2DLotsPadding(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 1, 1, 3],
@@ -665,7 +659,6 @@
padding=[[3, 4], [4, 2]])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2DExplicitPaddingWithDilations(self):
self._VerifyExplicitPaddings(
tensor_in_sizes=[1, 3, 2, 1],
@@ -681,7 +674,6 @@
padding=[[2, 1], [1, 2]],
dilations=[2, 3])
- @test_util.disable_xla("This test never passed for XLA")
def testConv2DExplicitPaddingWithLayoutOptimizer(self):
# Test with Grappler's layout optimizer, to ensure the layout optimizer
# handles explicit padding correctly.
@@ -1349,7 +1341,6 @@
dilations=dilations)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding0x0BackpropInput(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1372,7 +1363,6 @@
data_format=data_format)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding1x1BackpropInput(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1405,7 +1395,6 @@
dilations=[2, 2])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding2x2BackpropInput(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1430,7 +1419,6 @@
dilations=[2, 3])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding_1_8_4_1_BackpropInput(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1453,7 +1441,6 @@
data_format=data_format)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding_5_0_2_2_BackpropInput(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1512,7 +1499,6 @@
err=err)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding0x0BackpropFilter(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1535,7 +1521,6 @@
data_format=data_format)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding1x1BackpropFilter(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1569,7 +1554,6 @@
dilations=[2, 2])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding2x2BackpropFilter(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1594,7 +1578,6 @@
dilations=[2, 3])
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding_1_8_4_1_BackpropFilter(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1618,7 +1601,6 @@
data_format=data_format)
@test_util.run_in_graph_and_eager_modes()
- @test_util.disable_xla("This test never passed for XLA")
def testConv2D2x2Depth1Padding_5_0_2_2_BackpropFilter(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1976,7 +1958,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient1x1PaddingStrideOne(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -1998,7 +1979,6 @@
use_gpu=use_gpu,
max_err=0.0025)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient1x1PaddingStrideOne(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2019,7 +1999,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient1x1PaddingStrideTwo(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2040,7 +2019,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient1x1PaddingStrideTwo(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2061,7 +2039,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient2x2PaddingStrideOne(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2082,7 +2059,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient2x2PaddingStrideOne(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2104,7 +2080,6 @@
use_gpu=use_gpu,
max_err=0.003)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient1_2_3_4PaddingStride3x2(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2125,7 +2100,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient1_2_3_4PaddingStride3x2(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2146,7 +2120,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient4_3_2_1PaddingStride2x1(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2167,7 +2140,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient4_3_2_1PaddingStride2x1(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2188,7 +2160,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testInputGradient0_0_0_5PaddingStride1x2(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2209,7 +2180,6 @@
data_format=data_format,
use_gpu=use_gpu)
- @test_util.disable_xla("This test never passed for XLA")
def testFilterGradient0_0_0_5PaddingStride1x2(self):
if not test.is_gpu_available(cuda_only=True):
return
@@ -2316,7 +2286,7 @@
strides=[1, 1, 1, 1],
padding=[0, 0, 0, 0])
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testOpEdgeCases(self):
with self.cached_session() as sess:
# Illegal strides.
diff --git a/tensorflow/python/kernel_tests/depthtospace_op_test.py b/tensorflow/python/kernel_tests/depthtospace_op_test.py
index 97d3645..b7a865c 100644
--- a/tensorflow/python/kernel_tests/depthtospace_op_test.py
+++ b/tensorflow/python/kernel_tests/depthtospace_op_test.py
@@ -295,7 +295,7 @@
actual_vals, expected_vals = self.evaluate([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123553551") # Unsupported data format
def testAgainstTranspose(self):
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
self.compareToTranspose(3, 2, 3, 2, 2, "NHWC", False)
diff --git a/tensorflow/python/kernel_tests/diag_op_test.py b/tensorflow/python/kernel_tests/diag_op_test.py
index b813991..0bf48fd 100644
--- a/tensorflow/python/kernel_tests/diag_op_test.py
+++ b/tensorflow/python/kernel_tests/diag_op_test.py
@@ -65,7 +65,7 @@
array_ops.matrix_diag(0)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
@@ -270,7 +270,7 @@
array_ops.matrix_diag_part(0)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testInvalidShapeAtEval(self):
with self.session(use_gpu=True):
v = array_ops.placeholder(dtype=dtypes_lib.float32)
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 13f9acb..48e14d0 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -56,7 +56,6 @@
@test_util.with_control_flow_v2
-@test_util.disable_all_xla("This test never passed for XLA")
class FunctionalOpsTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
@@ -660,8 +659,7 @@
self.assertAllEqual(Run(100., True), 5050.)
@test_util.run_v1_only("b/120545219")
- @test_util.disable_xla(
- "This test never passed for XLA") # Different error message
+ @test_util.disable_xla("b/123337890") # Different error message
def testWhileError(self):
for use_gpu in (True, False):
with ops.Graph().as_default() as g:
@@ -938,7 +936,6 @@
# TODO(akshayka): Replace `function.Defun` with tf.contrib.eager.defun` in the
# below test cases.
-@test_util.disable_all_xla("This test never passed for XLA")
class PartitionedCallTest(test.TestCase):
@test_util.run_deprecated_v1
diff --git a/tensorflow/python/kernel_tests/gather_nd_op_test.py b/tensorflow/python/kernel_tests/gather_nd_op_test.py
index 76ae2fc..ad8376b 100644
--- a/tensorflow/python/kernel_tests/gather_nd_op_test.py
+++ b/tensorflow/python/kernel_tests/gather_nd_op_test.py
@@ -34,7 +34,6 @@
from tensorflow.python.platform import test
-@test_util.disable_all_xla("This test never passed for XLA")
class GatherNdTest(test.TestCase):
def _testSimpleDtype(self, dtype):
@@ -57,7 +56,7 @@
self._testSimpleDtype("|S") # byte strings in python2 + 3
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self):
with self.session(use_gpu=True):
params = np.ones((3, 3), dtype=np.float32)
@@ -360,7 +359,6 @@
self.assertAllEqual(expected_grads, ops.convert_to_tensor(grads).eval())
-@test_util.disable_all_xla("This test never passed for XLA")
class GatherNdOpBenchmark(test.Benchmark):
def benchmark_gather_nd_op(self):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py
index 1bed4b5..f70d8c4 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_adjoint_test.py
@@ -114,5 +114,29 @@
self.assertEqual("my_operator_adjoint", operator.name)
+class LinearOperatorAdjointNonSquareTest(
+ linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
+ """Tests done in the base class NonSquareLinearOperatorDerivedClassTest."""
+
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
+ shape_before_adjoint = list(build_info.shape)
+ # We need to swap the last two dimensions because we are taking the adjoint
+ # of this operator
+ shape_before_adjoint[-1], shape_before_adjoint[-2] = (
+ shape_before_adjoint[-2], shape_before_adjoint[-1])
+ matrix = linear_operator_test_util.random_normal(
+ shape_before_adjoint, dtype=dtype)
+
+ lin_op_matrix = matrix
+
+ if use_placeholder:
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
+
+ operator = LinearOperatorAdjoint(
+ linalg.LinearOperatorFullMatrix(lin_op_matrix))
+
+ return operator, linalg.adjoint(matrix)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py
index ec78a3f..12da865 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_algebra_test.py
@@ -26,15 +26,59 @@
from tensorflow.python.platform import test
# pylint: disable=protected-access
+_ADJOINTS = linear_operator_algebra._ADJOINTS
+_registered_adjoint = linear_operator_algebra._registered_adjoint
_CHOLESKY_DECOMPS = linear_operator_algebra._CHOLESKY_DECOMPS
-_MATMUL = linear_operator_algebra._MATMUL
_registered_cholesky = linear_operator_algebra._registered_cholesky
-_registered_matmul = linear_operator_algebra._registered_matmul
_INVERSES = linear_operator_algebra._INVERSES
_registered_inverse = linear_operator_algebra._registered_inverse
+_MATMUL = linear_operator_algebra._MATMUL
+_registered_matmul = linear_operator_algebra._registered_matmul
# pylint: enable=protected-access
+class AdjointTest(test.TestCase):
+
+ def testRegistration(self):
+
+ class CustomLinOp(linear_operator.LinearOperator):
+
+ def _matmul(self, a):
+ pass
+
+ def _shape(self):
+ return tensor_shape.TensorShape([1, 1])
+
+ def _shape_tensor(self):
+ pass
+
+ # Register Adjoint to a lambda that spits out the name parameter
+ @linear_operator_algebra.RegisterAdjoint(CustomLinOp)
+ def _adjoint(a): # pylint: disable=unused-argument,unused-variable
+ return "OK"
+
+ self.assertEqual("OK", CustomLinOp(dtype=None).adjoint())
+
+ def testRegistrationFailures(self):
+
+ class CustomLinOp(linear_operator.LinearOperator):
+ pass
+
+ with self.assertRaisesRegexp(TypeError, "must be callable"):
+ linear_operator_algebra.RegisterAdjoint(CustomLinOp)("blah")
+
+ # First registration is OK
+ linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)
+
+ # Second registration fails
+ with self.assertRaisesRegexp(ValueError, "has already been registered"):
+ linear_operator_algebra.RegisterAdjoint(CustomLinOp)(lambda a: None)
+
+ def testExactAdjointRegistrationsAllMatch(self):
+ for (k, v) in _ADJOINTS.items():
+ self.assertEqual(v, _registered_adjoint(k[0]))
+
+
class CholeskyTest(test.TestCase):
def testRegistration(self):
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
index 96e6e3c..28f8d20 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
@@ -136,6 +136,27 @@
self.assertTrue(operator.is_non_singular)
self.assertFalse(operator.is_self_adjoint)
+ def test_block_diag_adjoint_type(self):
+ matrix = [[1., 0.], [0., 1.]]
+ operator = block_diag.LinearOperatorBlockDiag(
+ [
+ linalg.LinearOperatorFullMatrix(
+ matrix,
+ is_non_singular=True,
+ ),
+ linalg.LinearOperatorFullMatrix(
+ matrix,
+ is_non_singular=True,
+ ),
+ ],
+ is_non_singular=True,
+ )
+ adjoint = operator.adjoint()
+ self.assertIsInstance(
+ adjoint,
+ block_diag.LinearOperatorBlockDiag)
+ self.assertEqual(2, len(adjoint.operators))
+
def test_block_diag_cholesky_type(self):
matrix = [[1., 0.], [0., 1.]]
operator = block_diag.LinearOperatorBlockDiag(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 4d7a31b..5c3220e 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -187,6 +187,11 @@
linalg_lib.LinearOperatorDiag))
self.assertAllClose([6., 9.], self.evaluate(operator_matmul.diag))
+ def test_diag_adjoint_type(self):
+ diag = [1., 3., 5., 8.]
+ operator = linalg.LinearOperatorDiag(diag, is_non_singular=True)
+ self.assertIsInstance(operator.adjoint(), linalg.LinearOperatorDiag)
+
def test_diag_cholesky_type(self):
diag = [1., 3., 5., 8.]
operator = linalg.LinearOperatorDiag(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index ea9ee99..55eff59 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -259,6 +259,12 @@
is_non_singular=None,
)
+ def test_identity_adjoint_type(self):
+ operator = linalg_lib.LinearOperatorIdentity(
+ num_rows=2, is_non_singular=True)
+ self.assertIsInstance(
+ operator.adjoint(), linalg_lib.LinearOperatorIdentity)
+
def test_identity_cholesky_type(self):
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2,
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index 54ccc0c..166188f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -192,6 +192,23 @@
with self.assertRaisesRegexp(ValueError, ">=1 operators"):
kronecker.LinearOperatorKronecker([])
+ def test_kronecker_adjoint_type(self):
+ matrix = [[1., 0.], [0., 1.]]
+ operator = kronecker.LinearOperatorKronecker(
+ [
+ linalg.LinearOperatorFullMatrix(
+ matrix, is_non_singular=True),
+ linalg.LinearOperatorFullMatrix(
+ matrix, is_non_singular=True),
+ ],
+ is_non_singular=True,
+ )
+ adjoint = operator.adjoint()
+ self.assertIsInstance(
+ adjoint,
+ kronecker.LinearOperatorKronecker)
+ self.assertEqual(2, len(adjoint.operators))
+
def test_kronecker_cholesky_type(self):
matrix = [[1., 0.], [0., 1.]]
operator = kronecker.LinearOperatorKronecker(
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index e203d1b..7d4c045 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -1235,9 +1235,8 @@
l = list_ops.tensor_list_push_back(l, [[0., 1.]])
l = list_ops.tensor_list_push_back(l, [[2.], [4.]])
with self.assertRaisesRegexp(
- errors.InvalidArgumentError,
- r"Tried to concat tensors with unequal shapes: "
- r"\[2\] vs \[1\]"):
+ errors.InvalidArgumentError, r"Incompatible shapes during merge: "
+ r"\[2\] vs. \[1\]"):
t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
self.evaluate(t)
@@ -1298,6 +1297,65 @@
t = list_ops.tensor_list_concat(l1, element_dtype=dtypes.float32)
self.evaluate(t)
+ def testConcatWithUninitializedTensorsUseListElementShape(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=[2, 3], num_elements=3)
+ t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
+ self.assertAllEqual(np.zeros((6, 3)), t)
+
+ def testConcatWithUninitializedTensorsUseProvidedElementShape(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=None, num_elements=3)
+ t = list_ops.tensor_list_concat(
+ l, element_dtype=dtypes.float32, element_shape=(2, 3))
+ self.assertAllEqual(np.zeros((6, 3)), t)
+
+ def testConcatWithUninitializedTensorsUseProvidedElementShapeAndLengths(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=None, num_elements=3)
+ t, _ = gen_list_ops.tensor_list_concat_v2(
+ l,
+ element_dtype=dtypes.float32,
+ element_shape=list_ops._build_element_shape((None, 3)),
+ leading_dims=[2, 3, 5])
+ self.assertAllEqual(np.zeros((10, 3)), t)
+ l = list_ops.tensor_list_set_item(l, 1, [[2., 3.], [4., 5.], [6., 7.]])
+ t, _ = gen_list_ops.tensor_list_concat_v2(
+ l,
+ element_dtype=dtypes.float32,
+ element_shape=list_ops._build_element_shape((None, 2)),
+ leading_dims=[2, 3, 4])
+ self.assertAllEqual([[0., 0.], [0., 0.], [2., 3.], [4., 5.], [6., 7.],
+ [0., 0.], [0., 0.], [0., 0.], [0., 0.]], t)
+
+ def testConcatWithUninitializedTensorsInferShapeFromElements(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=None, num_elements=3)
+ l = list_ops.tensor_list_set_item(l, 1, [[2., 3.], [4., 5.], [6., 7.]])
+ t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
+ self.assertAllEqual([[0., 0.], [0., 0.], [0., 0.], [2., 3.], [4., 5.],
+ [6., 7.], [0., 0.], [0., 0.], [0., 0.]], t)
+
+ def testConcatWithUninitializedTensorsFailsIfNoElementShape(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=None, num_elements=3)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"Trying to concat list with only uninitialized tensors "
+ r"but element_shape_except_first_dim_ is not fully defined"):
+ t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
+ self.evaluate(t)
+
+ def testConcatWithUninitializedTensorsFailsIfNoInputLengths(self):
+ l = list_ops.tensor_list_reserve(
+ element_dtype=dtypes.float32, element_shape=[None, 3], num_elements=3)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ r"List contains uninitialized tensor at index 0"
+ r" but leading_dims has only 0 elements."):
+ t = list_ops.tensor_list_concat(l, element_dtype=dtypes.float32)
+ self.evaluate(t)
+
def testEvenSplit(self):
def RunTest(input_tensor, lengths, expected_stacked_output):
diff --git a/tensorflow/python/kernel_tests/lookup_ops_test.py b/tensorflow/python/kernel_tests/lookup_ops_test.py
index bbadeab..0b78115 100644
--- a/tensorflow/python/kernel_tests/lookup_ops_test.py
+++ b/tensorflow/python/kernel_tests/lookup_ops_test.py
@@ -1788,9 +1788,10 @@
exported_keys, exported_values = table.export()
# exported data is in the order of the internal map, i.e. undefined
sorted_keys = np.sort(self.evaluate(exported_keys))
- sorted_values = np.sort(self.evaluate(exported_values))
+ sorted_values = np.sort(self.evaluate(exported_values), axis=0)
self.assertAllEqual([b"brain", b"salad", b"surgery"], sorted_keys)
- self.assertAllEqual([[4, 5], [2, 3], [0, 1]], sorted_values)
+ sorted_expected_values = np.sort([[4, 5], [2, 3], [0, 1]], axis=0)
+ self.assertAllEqual(sorted_expected_values, sorted_values)
def testMutableHashTableExportInsert(self):
with self.cached_session():
diff --git a/tensorflow/python/kernel_tests/map_fn_test.py b/tensorflow/python/kernel_tests/map_fn_test.py
index a31db41..41d99ea 100644
--- a/tensorflow/python/kernel_tests/map_fn_test.py
+++ b/tensorflow/python/kernel_tests/map_fn_test.py
@@ -49,7 +49,6 @@
@test_util.with_control_flow_v2
-@test_util.disable_all_xla("This test never passed for XLA")
class MapFnTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes
diff --git a/tensorflow/python/kernel_tests/pool_test.py b/tensorflow/python/kernel_tests/pool_test.py
index 367c94d..78e786f 100644
--- a/tensorflow/python/kernel_tests/pool_test.py
+++ b/tensorflow/python/kernel_tests/pool_test.py
@@ -303,7 +303,6 @@
self.assertLess(err, err_tolerance)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA") # Much larger error
def testGradient1D(self):
with self.session(use_gpu=test.is_gpu_available()):
for padding in ["SAME", "VALID"]:
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 94a861f..0cd6495 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -730,7 +730,7 @@
t = nn_ops.max_pool(
t, ksize=ksize, strides=strides, padding="SAME").eval()
- @test_util.disable_xla("b/123338077")
+ @test_util.disable_xla("b/123338077") # Passes with XLA
def testDepthwiseMaxPoolInvalidConfigs(self):
self._testDepthwiseMaxPoolInvalidConfig(
[1, 2, 2, 4], [1, 2, 2, 2], [1, 1, 1, 2],
@@ -1211,7 +1211,6 @@
[1, window_rows, window_cols, 1],
[1, row_stride, col_stride, 1], padding)
- @test_util.disable_xla("This test never passed for XLA")
def _testMaxPoolGradDirect(self, input_data, output_backprop,
expected_input_backprop, input_sizes, output_sizes,
window_rows, window_cols, row_stride, col_stride,
@@ -1353,6 +1352,7 @@
use_gpu=use_gpu,
v2=v2)
+ @test_util.disable_xla("b/123923733") # NaNs handled differently
def _testMaxPoolGradDirectWithNans2_1(self):
input_data = [float("nan")] * 16
output_backprop = [11.0, 12.0, 13.0, 15.0, 16.0, 17.0, 19.0, 20.0, 21.0]
@@ -1427,6 +1427,7 @@
else:
del os.environ["TF_ENABLE_MAXPOOL_NANPROP"]
+ @test_util.disable_xla("b/123923733") # NaNs handled differently
def _testMaxPoolGradDirectWithNans2_2(self):
input_data = [float("nan")] * 16
output_backprop = [
@@ -1627,7 +1628,6 @@
use_gpu=use_gpu)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testMaxPoolGradGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
self._testMaxPoolGradGradValidPadding1_1(data_format, use_gpu)
@@ -1662,7 +1662,6 @@
[1, row_stride, col_stride, 1], padding)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testAvgPoolGrad(self):
for (data_format, use_gpu) in GetTestConfigs():
self._testAvgPoolGradValidPadding1_1(data_format, use_gpu)
@@ -1822,7 +1821,7 @@
padding="SAME")
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testOpEdgeCases(self):
with self.session(use_gpu=test.is_gpu_available()) as sess:
pool_funcs = [nn_ops.max_pool, nn_ops.avg_pool]
@@ -1901,9 +1900,10 @@
if name_ == "maxpool5":
setattr(
PoolingTest, "testMaxPoolGrad_" + name_,
- test_util.disable_xla("maxpool5 fails while all others pass")(
- GetMaxPoolGradTest(input_size_, filter_size_, output_size_,
- stride_, padding_)))
+ test_util.disable_xla(
+ "b/123926014: incorrect output with only constants")(
+ GetMaxPoolGradTest(input_size_, filter_size_, output_size_,
+ stride_, padding_)))
else:
setattr(
PoolingTest, "testMaxPoolGrad_" + name_,
diff --git a/tensorflow/python/kernel_tests/qr_op_test.py b/tensorflow/python/kernel_tests/qr_op_test.py
index 5adb95c..f9b221a 100644
--- a/tensorflow/python/kernel_tests/qr_op_test.py
+++ b/tensorflow/python/kernel_tests/qr_op_test.py
@@ -67,8 +67,8 @@
val = self.evaluate(all_ops)
for i in range(8):
q = 4 * i
- self.assertAllEqual(val[q], val[q + 2]) # q1 == q2
- self.assertAllEqual(val[q + 1], val[q + 3]) # r1 == r2
+ self.assertAllClose(val[q], val[q + 2]) # q1 == q2
+ self.assertAllClose(val[q + 1], val[q + 3]) # r1 == r2
def _GetQrOpTest(dtype_, shape_, full_matrices_, use_static_shape_):
diff --git a/tensorflow/python/kernel_tests/reduction_ops_test_big.py b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
index 1e8524f..73bba54 100644
--- a/tensorflow/python/kernel_tests/reduction_ops_test_big.py
+++ b/tensorflow/python/kernel_tests/reduction_ops_test_big.py
@@ -21,16 +21,17 @@
import numpy as np
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-
class BaseReductionTest(test.TestCase):
def _tf_reduce(self, x, reduction_axes, keepdims):
raise NotImplementedError()
+@test_util.disable_all_xla("b/123864762") # Test times out
class BigReductionTest(BaseReductionTest):
"""Test reductions for sum and boolean all over a wide range of shapes."""
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index 3b89249..ca02aa6 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -86,7 +86,7 @@
self.assertAllClose(np_relu, tf_relu)
self.assertShapeEqual(np_relu, tf_relu)
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123338077") # Passes with XLA
def testReluInt8x4BadShape(self):
if not test.is_gpu_available(cuda_only=True):
self.skipTest("No GPU available")
diff --git a/tensorflow/python/kernel_tests/scan_ops_test.py b/tensorflow/python/kernel_tests/scan_ops_test.py
index 4e15894..2a3021f 100644
--- a/tensorflow/python/kernel_tests/scan_ops_test.py
+++ b/tensorflow/python/kernel_tests/scan_ops_test.py
@@ -70,7 +70,6 @@
return x
-@test_util.disable_all_xla("This test never passed for XLA")
class CumsumTest(test.TestCase):
valid_dtypes = [
@@ -135,6 +134,7 @@
self._compareAll(x, axis)
@test_util.run_deprecated_v1
+ @test_util.disable_xla("b/123860949") # The computation is constant folded
def testLarge(self):
for dtype in self.valid_dtypes:
x = np.ones([1000000], dtype=dtype) / 1024
@@ -194,7 +194,6 @@
self._compareGradient([5, 10], axis, exclusive, reverse)
-@test_util.disable_all_xla("This test never passed for XLA")
class CumprodTest(test.TestCase):
valid_dtypes = [
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 88f7b27..5bc301b 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -296,7 +296,7 @@
updates).get_shape().as_list(), shape)
@test_util.run_v1_only("b/120545219")
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testResVarInvalidOutputShape(self):
res = variables.Variable(
initial_value=lambda: array_ops.zeros(shape=[], dtype=dtypes.float32),
diff --git a/tensorflow/python/kernel_tests/signal/dct_ops_test.py b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
index e698afd..a3ac15b 100644
--- a/tensorflow/python/kernel_tests/signal/dct_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/dct_ops_test.py
@@ -134,7 +134,6 @@
@parameterized.parameters([
[[2]], [[3]], [[10]], [[2, 20]], [[2, 3, 25]]])
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def test_random(self, shape):
"""Test randomly generated batches of data."""
with spectral_ops_test_util.fft_kernel_label_map():
diff --git a/tensorflow/python/kernel_tests/signal/fft_ops_test.py b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
index f3bee87..4577587 100644
--- a/tensorflow/python/kernel_tests/signal/fft_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/fft_ops_test.py
@@ -159,7 +159,6 @@
raise ValueError("invalid rank")
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testEmpty(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type in (np.complex64, np.complex128):
@@ -170,7 +169,6 @@
self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testBasic(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 1e-8)):
@@ -180,7 +178,6 @@
np.mod(np.arange(np.power(4, dims)), 10).reshape(
(4,) * dims).astype(np_type), rank, rtol=tol, atol=tol)
- @test_util.disable_xla("This test never passed for XLA")
def testLargeBatch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
@@ -212,7 +209,6 @@
rank, use_placeholder=True, rtol=tol, atol=tol)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type, tol in ((np.complex64, 1e-4), (np.complex128, 5e-6)):
@@ -228,7 +224,6 @@
rtol=tol, atol=tol)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testRandom1D(self):
with spectral_ops_test_util.fft_kernel_label_map():
for np_type in (np.complex64, np.complex128):
@@ -345,7 +340,6 @@
raise ValueError("invalid rank")
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testEmpty(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
@@ -356,7 +350,6 @@
self.assertEqual(x.shape, self._tfIFFT(x, rank).shape)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testBasic(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
@@ -371,7 +364,6 @@
self._compareBackward(
c2r.astype(np.complex64), rank, (size,) * rank)
- @test_util.disable_xla("This test never passed for XLA")
def testLargeBatch(self):
if test.is_gpu_available(cuda_only=True):
rank = 1
@@ -405,7 +397,6 @@
rank, (size,) * rank,
use_placeholder=True)
- @test_util.disable_xla("This test never passed for XLA")
def testFftLength(self):
if test.is_gpu_available(cuda_only=True):
with spectral_ops_test_util.fft_kernel_label_map():
@@ -449,7 +440,6 @@
use_placeholder=True)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
def testRandom(self):
with spectral_ops_test_util.fft_kernel_label_map():
def gen_real(shape):
@@ -475,7 +465,7 @@
gen_complex(complex_dims), rank, (size,) * rank)
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123738986") # More assertions needed.
def testError(self):
with spectral_ops_test_util.fft_kernel_label_map():
for rank in VALID_FFT_RANKS:
diff --git a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
index a72a836..7b9748c 100644
--- a/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
+++ b/tensorflow/python/kernel_tests/signal/spectral_ops_test.py
@@ -21,7 +21,6 @@
import numpy as np
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
@@ -116,7 +115,6 @@
self.assertAllClose(
expected_inverse_stft, actual_inverse_stft, 1e-4, 1e-4)
- @test_util.disable_xla("This test never passed for XLA")
def test_shapes(self):
with spectral_ops_test_util.fft_kernel_label_map(), (
self.session(use_gpu=True)):
@@ -152,7 +150,6 @@
self.assertAllEqual([256], inverse_stft.shape.as_list())
self.assertAllEqual([expected_length], self.evaluate(inverse_stft).shape)
- @test_util.disable_xla("This test never passed for XLA")
def test_stft_and_inverse_stft(self):
"""Test that spectral_ops.stft/inverse_stft match a NumPy implementation."""
# Tuples of (signal_length, frame_length, frame_step, fft_length).
diff --git a/tensorflow/python/kernel_tests/spacetodepth_op_test.py b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
index 7f3c381..69243af 100644
--- a/tensorflow/python/kernel_tests/spacetodepth_op_test.py
+++ b/tensorflow/python/kernel_tests/spacetodepth_op_test.py
@@ -285,7 +285,7 @@
actual_vals, expected_vals = self.evaluate([actual, expected])
self.assertTrue(np.array_equal(actual_vals, expected_vals))
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123553551") # Unsupported data format
def testAgainstTranspose(self):
self.compareToTranspose(3, 2, 3, 1, 2, "NHWC", False)
self.compareToTranspose(1, 2, 3, 2, 2, "NHWC", False)
diff --git a/tensorflow/python/kernel_tests/split_op_test.py b/tensorflow/python/kernel_tests/split_op_test.py
index 80004db..42b4d1b 100644
--- a/tensorflow/python/kernel_tests/split_op_test.py
+++ b/tensorflow/python/kernel_tests/split_op_test.py
@@ -373,7 +373,7 @@
assert s1.shape.as_list() == [1]
@test_util.run_deprecated_v1
- @test_util.disable_xla("This test never passed for XLA")
+ @test_util.disable_xla("b/123337890") # Error messages differ
def testNonexistentDimTensor(self):
x = array_ops.placeholder(dtypes.int32)
values = np.zeros([5, 30])
diff --git a/tensorflow/python/kernel_tests/stack_op_test.py b/tensorflow/python/kernel_tests/stack_op_test.py
index ca3357a..04d635c 100644
--- a/tensorflow/python/kernel_tests/stack_op_test.py
+++ b/tensorflow/python/kernel_tests/stack_op_test.py
@@ -81,7 +81,7 @@
np.random.seed(7)
with self.session(use_gpu=True):
for shape in (2,), (3,), (2, 3), (3, 2), (4, 3, 2):
- for dtype in [np.bool, np.float32, np.int32, np.int64]:
+ for dtype in [np.bool, np.float32, np.int16, np.int32, np.int64]:
data = np.random.randn(*shape).astype(dtype)
# Stack back into a single tensorflow tensor directly using np array
c = array_ops.stack(data)
diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
index d0efb47..b4544ba 100644
--- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py
+++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py
@@ -185,7 +185,6 @@
self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
- @test_util.disable_control_flow_v2("b/122324791")
@test_util.run_v1_only("b/122324791")
def testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros(self):
self._testTensorArrayReadOrPackNotAllValuesAvailableFillsZeros()
@@ -202,11 +201,21 @@
self.assertAllEqual([[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]],
self.evaluate(ta.write(1, [[4.0, 5.0]]).concat()))
- @test_util.disable_control_flow_v2("b/122324791")
@test_util.run_v1_only("b/122324791")
def testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros(self):
self._testTensorArrayReadOrPackNotAllValuesAvailableInferShapeFillsZeros()
+ @test_util.run_v1_only("Uses placeholders")
+ def testSkipEagerTensorArrayReadUninitializedInferShapeFillsZeros(self):
+ with self.cached_session(use_gpu=True) as sess:
+ ta = tensor_array_ops.TensorArray(
+ dtype=dtypes.float32,
+ tensor_array_name="foo",
+ size=3)
+ val = array_ops.placeholder(dtypes.float32)
+ self.assertAllEqual(
+ [[0.0, 0.0]], sess.run(ta.write(1, val).read(0), {val: [[4.0, 5.0]]}))
+
def _testTensorArrayUnpackRead(self, tf_dtype):
with self.cached_session(use_gpu=True):
convert = _make_converter(tf_dtype)
@@ -531,7 +540,10 @@
# The exact error messages differ between eager execution and graph
# construction as the former bubbles up the error from array_op.concat.
- with self.assertRaisesOpError("shape"):
+ error_msg = ("Incompatible ranks"
+ if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
+ not context.executing_eagerly() else "shape")
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, error_msg):
self.evaluate(w3.concat())
def testTensorArraySplitIncompatibleShapesFails(self):
@@ -1501,7 +1513,7 @@
if "/task:1/" in d:
self.assertTrue(
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
- else:
+ elif "/host:CPU" not in d:
self.assertFalse(
[s for s in dev_stats[d] if "/TensorArray" in s.node_name])
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index b2d54a9..7e12dca 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -64,7 +64,7 @@
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
name: String, the name of the layer. Layers with the same name will
share weights, but to avoid mistakes we require reuse=True in such cases.
- reuse: Boolean, whether to reuse the weights of a previous layer
+ _reuse: Boolean, whether to reuse the weights of a previous layer
by the same name.
Properties:
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index f681cff..6cdf7e7 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -64,6 +64,19 @@
PyIsInstance(obj, &PyFloatingArrType_Type); // NumPy float types
}
+// If the input is a zero dimensional PyArray return it converted to a scalar.
+// Otherwise return the input and increment its reference count.
+// Users must Py_DECREF the output of this method.
+PyObject* ZeroDimArrayToScalar(PyObject* obj) {
+ if (PyArray_IsZeroDim(obj) && !PyArray_IsScalar(obj, Generic)) {
+ auto pyarray_obj = reinterpret_cast<PyArrayObject*>(obj);
+ obj = PyArray_ToScalar(PyArray_DATA(pyarray_obj), pyarray_obj);
+ } else {
+ Py_INCREF(obj);
+ }
+ return obj;
+}
+
// Converts Python object `c` that should hold a Python string into a
// C++ string in *out. Returns nullptr on success, or a message on error.
// Defined below, but forward declared here for use in PyRepr.
@@ -130,6 +143,10 @@
Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
std::vector<Safe_PyObjectPtr> refs_to_clean;
while (true) {
+ // Convert any zero dimensional numpy arrays to scalars first of all.
+ // We also have to make sure a reference to the safe_obj is kept.
+ obj = ZeroDimArrayToScalar(obj);
+ refs_to_clean.push_back(make_safe(obj));
// We test strings first, in case a string is considered a sequence.
if (IsPyString(obj)) {
*dtype = DT_STRING;
@@ -240,7 +257,9 @@
} \
PyObject** l = PySequence_Fast_ITEMS(seq.get()); \
for (int64 i = 0; i < s; ++i) { \
- const char* error = CONVERT(l[i], *buf); \
+ auto scalar = ZeroDimArrayToScalar(l[i]); \
+ const char* error = CONVERT(scalar, *buf); \
+ Py_DECREF(scalar); \
if (TF_PREDICT_FALSE(error != nullptr)) return error; \
++*buf; \
} \
@@ -253,7 +272,9 @@
Tensor result(TYPE_ENUM, shape); \
if (shape.dims() == 0) { /* Scalar case */ \
TYPE value; \
- const char* error = CONVERT(obj, &value); \
+ auto scalar = ZeroDimArrayToScalar(obj); \
+ const char* error = CONVERT(scalar, &value); \
+ Py_DECREF(scalar); \
if (error != nullptr) return error; \
result.scalar<TYPE>()() = value; \
} else { \
diff --git a/tensorflow/python/module/module.py b/tensorflow/python/module/module.py
index 62d1bd4..6bd2755 100644
--- a/tensorflow/python/module/module.py
+++ b/tensorflow/python/module/module.py
@@ -131,7 +131,7 @@
>>> class Dense(tf.Module):
... def __init__(self, in_features, output_features):
- ... super(Linear, self).__init__()
+ ... super(Dense, self).__init__()
... self.w = tf.Variable(
... tf.random_normal([input_features, output_features]), name='w')
... self.b = tf.Variable(tf.zeros([output_features]), name='b')
@@ -150,7 +150,7 @@
By subclassing `tf.Module` instead of `object` any variables created inside
the module are automatically created within the modules name scope:
- >> d.w.name
+ >>> d.w.name
"dense/w:0"
In eager mode this is useful for debugging, and when used with `@tf.function`
@@ -199,85 +199,35 @@
@property
def variables(self):
- """Collection of variables owned by this module and it's submodules.
+ """Sequence of variables owned by this module and it's submodules.
Note: this method uses reflection to find variables on the current instance
and submodules. For performance reasons you may wish to cache the result
of calling this method if you don't expect the return value to change.
Returns:
- A collection of variables for the current module (sorted by attribute
+ A sequence of variables for the current module (sorted by attribute
name) followed by variables from all submodules recursively (depth first).
"""
- return tuple(walk(self, recurse_if=_IS_MODULE, predicate=_IS_VARIABLE))
-
- @property
- def owned_variables(self):
- """Collection of variables that are attributes of the current module.
-
- See `variables` for a property which returns all variables from the current
- module and all it's submodules recursively.
-
- Returns:
- A collection of variables which are attributes of the current module. Will
- yield variables inside nested structures (lists etc) but not in other
- modules.
- """
- return tuple(walk(self, predicate=_IS_VARIABLE))
+ return tuple(self._flatten(predicate=_IS_VARIABLE))
@property
def trainable_variables(self):
- """Collection of variables owned by this module and it's submodules.
+ """Sequence of variables owned by this module and it's submodules.
Note: this method uses reflection to find variables on the current instance
and submodules. For performance reasons you may wish to cache the result
of calling this method if you don't expect the return value to change.
Returns:
- A collection of variables for the current module (sorted by attribute
+ A sequence of variables for the current module (sorted by attribute
name) followed by variables from all submodules recursively (depth first).
"""
- return tuple(
- walk(self, recurse_if=_IS_MODULE, predicate=_IS_TRAINABLE_VARIABLE))
-
- @property
- def owned_trainable_variables(self):
- """Collection of variables that are attributes of the current module.
-
- See `variables` for a property which returns all variables from the current
- module and all it's submodules recursively.
-
- Returns:
- A collection of variables which are attributes of the current module. Will
- yield variables inside nested structures (lists etc) but not in other
- modules.
- """
- return tuple(walk(self, predicate=_IS_TRAINABLE_VARIABLE))
-
- @property
- def owned_submodules(self):
- """Collection of immediate child modules.
-
- Child modules are modules which are found as properties of the current
- module.
-
- >>> a = tf.experimental.Module()
- >>> b = tf.experimental.Module()
- >>> c = tf.experimental.Module()
- >>> a.b = b
- >>> b.c = c
- >>> assert list(a.owned_submodules) == [b]
- >>> assert list(b.owned_submodules) == [c]
- >>> assert list(c.owned_submodules) == []
-
- Returns:
- A collection of all child modules.
- """
- return tuple(walk(self, predicate=_IS_MODULE))
+ return tuple(self._flatten(predicate=_IS_TRAINABLE_VARIABLE))
@property
def submodules(self):
- """Collection of all sub-modules.
+ """Sequence of all sub-modules.
Submodules are modules which are properties of this module, or found as
properties of modules which are properties of this module (and so on).
@@ -292,9 +242,62 @@
>>> assert list(c.submodules) == []
Returns:
- A collection of all submodules.
+ A sequence of all submodules.
"""
- return tuple(walk(self, recurse_if=_IS_MODULE, predicate=_IS_MODULE))
+ return tuple(self._flatten(predicate=_IS_MODULE))
+
+ def _flatten(self,
+ recursive=True,
+ predicate=None,
+ attribute_traversal_key=None):
+ """Flattened attribute values in sorted order by attribute name.
+
+ Modules are flattened by first walking their attributes in name order.
+ Each attribute value is then flattened to find leaf values. If flatten is
+ to be applied `recursive`ly then if the leaf is a `Module` it will also be
+ flattened to find leaves. Finally every leaf value is optionally tested
+ against the given `predicate` and finally yielded.
+
+ >>> class Foo(tf.experimental.Module):
+ ... def __init__(self):
+ ... super(Foo, self).__init__()
+ ... self.x = [tf.constant('a'), tf.constant('b')]
+ ... self.y = {'i': tf.constant('c'), 'j': tf.constant('d')}
+ ... self.z = tf.constant('e')
+ ...
+ ... @property
+ ... def tensors(self):
+ ... return tuple(self._flatten(predicate=is_tensor))
+
+ >>> foo = Foo()
+ >>> foo.tensors
+ (<tf.Tensor...'a'>, <tf.Tensor...'b'>, ...'c'>, ...'d'>, ...'e'>)
+
+ `attribute_traversal_key` controls the order object properties are visited.
+ If not set objects are visited in ascending order by name.
+
+ Args:
+ recursive: Whether to recurse into child modules or not.
+ predicate: (Optional) If set then only values matching predicate are
+ yielded. A value of `None` (the default) means no items will be
+ filtered.
+ attribute_traversal_key: (Optional) Method to rekey object attributes
+ before they are sorted. Contract is the same as `key` argument to
+ builtin `sorted` and only applies to object properties.
+
+ Returns:
+ Flat generator for leaves of the current module and optionally all
+ submodules.
+ """
+ if predicate is None:
+ predicate = lambda _: True
+
+ return _flatten_module(
+ self,
+ recursive=recursive,
+ predicate=predicate,
+ attribute_traversal_key=attribute_traversal_key,
+ seen=set())
@classmethod
def no_name_scope(cls, method):
@@ -334,77 +337,37 @@
return _CAMEL_TO_SNAKE_R.sub(r"_\1", value).lower()
-def walk(o, recurse_if=None, predicate=None):
- """Flattened attributes of `o` in sorted order by attribute name.
-
- >>> class Foo(object):
- ... def __init__(self, prefix=''):
- ... self.z = prefix + 'c'
- ... self.a = [prefix + 'a', prefix + 'b']
-
- >>> tuple(walk(Foo()))
- ('a', 'b', 'c')
-
- If `predicate` is not None, then only values matching predicate are returned:
-
- >>> tuple(walk(Foo(), predicate=lambda v: v != 'a'))
- ('b', 'c')
-
- If `recurse_if` is not None then it should be a callable which tests if the
- given leaf should be expanded:
-
- >>> is_string = lambda v: isinstance(v, str)
- >>> is_foo = lambda l: isinstance(l, Foo)
- >>> o = Foo(prefix='root_')
- >>> o.b = Foo(prefix='child_')
- >>> tuple(walk(o, predicate=is_string))
- ('root_a', 'root_b', 'root_c')
- >>> tuple(walk(o, recurse_if=is_foo, predicate=is_string))
- ('root_a', 'root_b', 'root_c', 'child_a', 'child_b', 'child_c')
-
- Args:
- o: An object who's attributes are walked.
- recurse_if: (Optional) Visited items of this type will be walked to extract
- more leaves. If `None`, it will not recurse into leaves.
- predicate: (Optional) If set then only values matching predicate are
- yielded.
-
- Returns:
- Attributes of `o` in name order. If `recurse_if` is not `None` then
- attributes for which `recurse_if(attribute) == True` will be walked
- recursively. If `predicate` is not `None` then only attributes for which
- `predicate(attribute) == True` will be yielded.
- """
- if predicate is None:
- predicate = lambda _: True
- return _walk_internal(
- o, recurse_if=recurse_if, predicate=predicate, seen=set())
-
-
-def _walk_internal(o, recurse_if, predicate, seen):
- """Implementation of `walk`."""
+def _flatten_module(module, recursive, predicate, attribute_traversal_key,
+ seen):
+ """Implementation of `flatten`."""
if seen is None:
- seen = set([id(o)])
+ seen = set([id(module)])
- o_dict = vars(o)
- to_walk = []
+ module_dict = vars(module)
+ submodules = []
- for key in sorted(o_dict):
- values = nest.flatten(o_dict[key])
- for value in values:
- value_id = id(value)
- if value_id in seen:
+ for key in sorted(module_dict, key=attribute_traversal_key):
+ for leaf in nest.flatten(module_dict[key]):
+ leaf_id = id(leaf)
+ if leaf_id in seen:
continue
- seen.add(value_id)
- if predicate(value):
- yield value
+ seen.add(leaf_id)
+ if predicate(leaf):
+ yield leaf
- if recurse_if is not None and recurse_if(value):
+ if recursive and isinstance(leaf, Module):
# Walk direct properties first then recurse.
- to_walk.append(value)
+ submodules.append(leaf)
- for value in to_walk:
- for subvalue in _walk_internal(value, recurse_if, predicate, seen):
+ for submodule in submodules:
+ subvalues = _flatten_module(
+ submodule,
+ recursive=recursive,
+ predicate=predicate,
+ attribute_traversal_key=attribute_traversal_key,
+ seen=seen)
+
+ for subvalue in subvalues:
# Predicate is already tested for these values.
yield subvalue
diff --git a/tensorflow/python/module/module_test.py b/tensorflow/python/module/module_test.py
index 1e4e195..da9bcf7 100644
--- a/tensorflow/python/module/module_test.py
+++ b/tensorflow/python/module/module_test.py
@@ -151,12 +151,6 @@
self.assertEqual(m.child.variables, (m.child.w, m.child.child.w))
self.assertEqual(m.child.child.variables, (m.child.child.w,))
- def test_owned_variables(self):
- m = RecursiveModule(3)
- self.assertEqual(m.owned_variables, (m.w,))
- self.assertEqual(m.child.owned_variables, (m.child.w,))
- self.assertEqual(m.child.child.owned_variables, (m.child.child.w,))
-
def test_trainable_variables(self):
m = RecursiveModule(3)
self.assertEqual(m.trainable_variables,
@@ -171,28 +165,9 @@
self.assertEqual(len(m.child.trainable_variables), 0)
self.assertEqual(len(m.child.child.trainable_variables), 0)
- def test_owned_trainable_variables(self):
- m = RecursiveModule(3)
- self.assertEqual(m.owned_trainable_variables, (m.w,))
- self.assertEqual(m.child.owned_trainable_variables, (m.child.w,))
- self.assertEqual(m.child.child.owned_trainable_variables,
- (m.child.child.w,))
-
- def test_owned_trainable_variables_ignores_non_trainable(self):
- m = RecursiveModule(3, trainable=False)
- self.assertEqual(len(m.owned_trainable_variables), 0)
- self.assertEqual(len(m.child.owned_trainable_variables), 0)
- self.assertEqual(len(m.child.child.owned_trainable_variables), 0)
-
class ModuleTrackingTest(test.TestCase):
- def test_owned_submodules(self):
- m = RecursiveModule(3)
- self.assertEqual(list(m.owned_submodules), [m.child])
- self.assertEqual(list(m.child.owned_submodules), [m.child.child])
- self.assertEqual(list(m.child.child.owned_submodules), [])
-
def test_submodules(self):
m = RecursiveModule(3)
self.assertEqual(list(m.submodules), [m.child, m.child.child])
@@ -332,13 +307,43 @@
child = parent.c
self.assertEqual(
- list(module.walk(parent, predicate=IS_MEMBER)),
+ list(parent._flatten(recursive=False, predicate=IS_MEMBER)),
[parent.a[0], parent.a[1], parent.z])
self.assertEqual(
- list(module.walk(parent, recurse_if=IS_MODULE, predicate=IS_MEMBER)),
+ list(parent._flatten(predicate=IS_MEMBER)),
[parent.a[0], parent.a[1], parent.z, child.a[0], child.a[1], child.z])
+ def test_attribute_traversal_key(self):
+ mod = LayerModule()
+ self.assertEqual(
+ mod.variables,
+ mod._trainable_variables + mod._non_trainable_variables + [mod._bonus])
+
+
+class LayerModule(module.Module):
+
+ def __init__(self):
+ super(LayerModule, self).__init__()
+ self._trainable_variables = [
+ variables.Variable(1., name="a"),
+ variables.Variable(2., name="b"),
+ ]
+ self._non_trainable_variables = [
+ variables.Variable(3., name="c"),
+ variables.Variable(4., name="d"),
+ ]
+ self._bonus = variables.Variable(5., name="e")
+
+ @property
+ def variables(self):
+ def key_function(name):
+ indexes = {"_trainable_variables": 0, "_non_trainable_variables": 1}
+ return indexes.get(name, 2), name
+
+ return list(self._flatten(predicate=module._IS_VARIABLE,
+ attribute_traversal_key=key_function))
+
class MemberType(object):
"""A simple type to search for."""
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index 89d6de6..230ce62 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -3375,7 +3375,7 @@
@dispatch.add_dispatch_support
@deprecation.deprecated(
"2017-10-25", "`tf.batch_gather` is deprecated, please use `tf.gather` "
- "with `batch_dims=-1` instead.") # pylint: disable=missing-docstring
+ "with `batch_dims` instead.") # pylint: disable=missing-docstring
def batch_gather(params, indices, name=None):
"""Gather slices from params according to indices with leading batch dims."""
with ops.name_scope(name, "BatchGather", [params, indices]):
diff --git a/tensorflow/python/ops/boosted_trees_ops.py b/tensorflow/python/ops/boosted_trees_ops.py
index 37d649a..f6c3702 100644
--- a/tensorflow/python/ops/boosted_trees_ops.py
+++ b/tensorflow/python/ops/boosted_trees_ops.py
@@ -61,7 +61,36 @@
sorted(cls._map))))
-class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
+class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
+ """SaveableObject implementation for QuantileAccumulator."""
+
+ def __init__(self, resource_handle, create_op, num_streams, name):
+ self._resource_handle = resource_handle
+ self._num_streams = num_streams
+ self._create_op = create_op
+ bucket_boundaries = get_bucket_boundaries(self._resource_handle,
+ self._num_streams)
+ slice_spec = ''
+ specs = []
+
+ def make_save_spec(tensor, suffix):
+ return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix)
+
+ for i in range(self._num_streams):
+ specs += [
+ make_save_spec(bucket_boundaries[i], '_bucket_boundaries_' + str(i))
+ ]
+ super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle,
+ specs, name)
+
+ def restore(self, restored_tensors, unused_tensor_shapes):
+ bucket_boundaries = restored_tensors
+ with ops.control_dependencies([self._create_op]):
+ return quantile_resource_deserialize(
+ self._resource_handle, bucket_boundaries=bucket_boundaries)
+
+
+class QuantileAccumulator(tracking.TrackableResource):
"""SaveableObject implementation for QuantileAccumulator.
The bucket boundaries are serialized and deserialized from checkpointing.
@@ -73,55 +102,58 @@
num_quantiles,
name=None,
max_elements=None):
+ self._eps = epsilon
+ self._num_streams = num_streams
+ self._num_quantiles = num_quantiles
+ super(QuantileAccumulator, self).__init__()
+
with ops.name_scope(name, 'QuantileAccumulator') as name:
- self._eps = epsilon
- self._num_streams = num_streams
- self._num_quantiles = num_quantiles
- self._resource_handle = quantile_resource_handle_op(
- container='', shared_name=name, name=name)
- self._create_op = create_quantile_stream_resource(self._resource_handle,
- epsilon, num_streams)
- is_initialized_op = is_quantile_resource_initialized(
- self._resource_handle)
- resources.register_resource(self._resource_handle, self._create_op,
- is_initialized_op)
- self._make_saveable(name)
+ self._name = name
+ self._resource_handle = self.create_resource()
+ self._init_op = self.initialize()
+ is_initialized_op = self.is_initialized()
+ resources.register_resource(self.resource_handle, self._init_op,
+ is_initialized_op)
+ self._saveable = QuantileAccumulatorSaveable(
+ self.resource_handle, self._init_op, self._num_streams,
+ self.resource_handle.name)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
- def _make_saveable(self, name):
- bucket_boundaries = get_bucket_boundaries(self._resource_handle,
- self._num_streams)
- slice_spec = ''
- specs = []
- for i in range(self._num_streams):
- specs.append(
- saver.BaseSaverBuilder.SaveSpec(
- bucket_boundaries[i], slice_spec,
- name + '_bucket_boundaries_' + str(i)))
- super(QuantileAccumulator, self).__init__(self._resource_handle, specs,
- name)
- ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self)
+ def create_resource(self):
+ return quantile_resource_handle_op(
+ container='', shared_name=self._name, name=self._name)
- def restore(self, restored_tensors, unused_tensor_shapes):
- bucket_boundaries = restored_tensors
- with ops.control_dependencies([self._create_op]):
- return quantile_resource_deserialize(
- self._resource_handle, bucket_boundaries=bucket_boundaries)
+ def initialize(self):
+ return create_quantile_stream_resource(self.resource_handle, self._eps,
+ self._num_streams)
+
+ @property
+ def initializer(self):
+ if self._init_op is None:
+ self._init_op = self.initialize()
+ return self._init_op
+
+ def is_initialized(self):
+ return is_quantile_resource_initialized(self.resource_handle)
+
+ @property
+ def saveable(self):
+ return self._saveable
+
+ def _gather_saveables_for_checkpoint(self):
+ return {'quantile_accumulator', self._saveable}
def add_summaries(self, float_columns, example_weights):
summaries = make_quantile_summaries(float_columns, example_weights,
self._eps)
- summary_op = quantile_add_summaries(self._resource_handle, summaries)
+ summary_op = quantile_add_summaries(self.resource_handle, summaries)
return summary_op
def flush(self):
- return quantile_flush(self._resource_handle, self._num_quantiles)
+ return quantile_flush(self.resource_handle, self._num_quantiles)
def get_bucket_boundaries(self):
- return get_bucket_boundaries(self._resource_handle, self._num_streams)
-
- @property
- def resource(self):
- return self._resource_handle
+ return get_bucket_boundaries(self.resource_handle, self._num_streams)
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index f1f3626..b452b4a 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -1526,6 +1526,25 @@
v1=['debugging.is_numeric_tensor', 'is_numeric_tensor'])
@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
+ """Returns `True` if the elements of `tensor` are numbers.
+
+ Specifically, returns `True` if the dtype of `tensor` is one of the following:
+
+ * `tf.float32`
+ * `tf.float64`
+ * `tf.int8`
+ * `tf.int16`
+ * `tf.int32`
+ * `tf.int64`
+ * `tf.uint8`
+ * `tf.qint8`
+ * `tf.qint32`
+ * `tf.quint8`
+ * `tf.complex64`
+
+ Returns `False` if `tensor` is of a non-numeric type or if `tensor` is not
+ a `tf.Tensor` object.
+ """
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
@@ -1702,7 +1721,7 @@
@tf_export(v1=['debugging.assert_scalar', 'assert_scalar'])
@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None, message=None):
- """Asserts that the given `tensor` is a scalar.
+ """Asserts that the given `tensor` is a scalar (i.e. zero-dimensional).
This function raises `ValueError` unless it can be certain that the given
`tensor` is a scalar. `ValueError` is also raised if the shape of `tensor` is
diff --git a/tensorflow/python/ops/cond_v2.py b/tensorflow/python/ops/cond_v2.py
index a0a13fb..74f5b52 100644
--- a/tensorflow/python/ops/cond_v2.py
+++ b/tensorflow/python/ops/cond_v2.py
@@ -35,7 +35,7 @@
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest
@@ -277,7 +277,7 @@
ys = []
grad_ys = []
for y, grad_y in zip(func_graph.outputs, grads):
- if not gradients_impl.IsTrainable(y):
+ if not gradients_util.IsTrainable(y):
continue
ys.append(y)
grad_ys.append(grad_y)
@@ -286,7 +286,7 @@
# func_graph in the current graph, which requires capturing tensors from
# func_graph. The captured func_graph tensors are resolved to external tensors
# in _resolve_grad_inputs.
- result = gradients_impl._GradientsHelper(
+ result = gradients_util._GradientsHelper(
ys, func_graph.inputs, grad_ys=grad_ys,
src_graph=func_graph)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index cfdbe63..e0b83c4 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -59,13 +59,13 @@
from tensorflow.python.util.tf_export import tf_export
# This is to avoid a circular dependency:
-# cond_v2 -> gradients_impl -> control_flow_ops
+# cond_v2 -> gradients_util -> control_flow_ops
cond_v2 = LazyLoader("cond_v2", globals(),
"tensorflow.python.ops.cond_v2")
# This is to avoid circular dependencies:
# while_v2 -> control_flow_ops
-# while_v2 -> gradients_impl -> control_flow_ops
+# while_v2 -> gradients_util -> control_flow_ops
while_v2 = LazyLoader("while_v2", globals(),
"tensorflow.python.ops.while_v2")
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index d0291e2..881466c 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -558,7 +558,7 @@
combiner=None,
max_norm=None,
name=None):
- return embedding_lookup_sparse_v2(
+ return embedding_lookup_sparse(
params, sp_ids, sp_weights, partition_strategy, name, combiner, max_norm)
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 324f55a..448e45c 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -397,13 +397,15 @@
ops.convert_to_tensor(elem, name="elem") for elem in elems_flat]
# Convert elems to tensor array. n may be known statically.
- n = (tensor_shape.dimension_value(elems_flat[0].shape[0])
- or array_ops.shape(elems_flat[0])[0])
+ n = tensor_shape.dimension_value(elems_flat[0].shape[0])
+ if n is None:
+ n = array_ops.shape(elems_flat[0])[0]
# TensorArrays are always flat
elems_ta = [
tensor_array_ops.TensorArray(dtype=elem.dtype, size=n,
dynamic_size=False,
+ element_shape=elem.shape[1:],
infer_shape=True)
for elem in elems_flat]
# Unpack elements
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index cd11447..96389ab 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -22,7 +22,7 @@
from tensorflow.python.eager import function
from tensorflow.python.eager.backprop import GradientTape
from tensorflow.python.ops.custom_gradient import custom_gradient
-from tensorflow.python.ops.gradients_impl import AggregationMethod
+from tensorflow.python.ops.gradients_util import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 0a70d6e..c66efad 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -18,30 +18,14 @@
from __future__ import division
from __future__ import print_function
-import collections
-import contextlib
-import warnings
-
-import numpy as np
-import six
-from six.moves import xrange # pylint: disable=redefined-builtin
-
-from tensorflow.core.framework import attr_value_pb2
-from tensorflow.python.eager import context
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_grad # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import control_flow_util
-from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import image_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_grad # pylint: disable=unused-import
from tensorflow.python.ops import linalg_ops # pylint: disable=unused-import
@@ -51,503 +35,11 @@
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import optional_grad # pylint: disable=unused-import
from tensorflow.python.ops import random_grad # pylint: disable=unused-import
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
-# This is to avoid a circular dependency (eager.function depends on
-# gradients_impl). This is set in eager/function.py.
-_function = None
-
-# Warn the user if we convert a sparse representation to dense with at
-# least this number of elements.
-_LARGE_SPARSE_NUM_ELEMENTS = 100000000
-
-
-def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
- """Converts an IndexedSlices object `value` to a Tensor.
-
- NOTE(mrry): This function is potentially expensive.
-
- Args:
- value: An ops.IndexedSlices object.
- dtype: The dtype of the Tensor to be returned.
- name: Optional name to use for the returned Tensor.
- as_ref: True if a ref is requested.
-
- Returns:
- A dense Tensor representing the values in the given IndexedSlices.
-
- Raises:
- ValueError: If the IndexedSlices does not have the same dtype.
- """
- _ = as_ref
- if dtype and not dtype.is_compatible_with(value.dtype):
- raise ValueError(
- "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
- (dtype.name, value.dtype.name))
- if value.dense_shape is None:
- raise ValueError(
- "Tensor conversion requested for IndexedSlices without dense_shape: %s"
- % str(value))
- # TODO(mrry): Consider adding static shape information to
- # IndexedSlices, to avoid using numpy here.
- if not context.executing_eagerly():
- dense_shape_value = tensor_util.constant_value(value.dense_shape)
- if dense_shape_value is not None:
- num_elements = np.prod(dense_shape_value)
- if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
- warnings.warn(
- "Converting sparse IndexedSlices to a dense Tensor with %d "
- "elements. This may consume a large amount of memory." %
- num_elements)
- else:
- warnings.warn(
- "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
- "This may consume a large amount of memory.")
- return math_ops.unsorted_segment_sum(
- value.values, value.indices, value.dense_shape[0], name=name)
-
-
-ops.register_tensor_conversion_function(ops.IndexedSlices,
- _IndexedSlicesToTensor)
-
-
-def _MarkReachedOps(from_ops, reached_ops, func_graphs):
- """Mark all ops reached from "from_ops".
-
- Args:
- from_ops: list of Operations.
- reached_ops: set of Operations.
- func_graphs: list of FuncGraphs. This method will traverse through
- these functions if they capture from_ops or any reachable ops.
- """
- queue = collections.deque()
- queue.extend(from_ops)
- while queue:
- op = queue.popleft()
- if op not in reached_ops:
- reached_ops.add(op)
- for output in op.outputs:
- if _IsBackpropagatable(output):
- queue.extend(_Consumers(output, func_graphs))
-
-
-def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
- xs):
- """Initialize the pending count for ops between two lists of Operations.
-
- 'pending_count[op]' indicates the number of backprop inputs
- to this operation.
-
- Args:
- to_ops: list of Operations.
- from_ops: list of Operations.
- colocate_gradients_with_ops: Python bool. See docstring of gradients().
- func_graphs: list of FuncGraphs. This method will traverse through
- these functions if they capture from_ops or any reachable ops. This is
- useful if to_ops occur in a function and from_ops are in an outer function
- or graph.
- xs: list of Tensors.
-
- Returns:
- A tuple containing: (1) the subset of to_ops reachable from from_ops by a
- path of zero or more backpropagatable tensors, (2) a mapping from operation
- to the number of backprop inputs to that op, and (3) a ControlFlowState
- object which is not None if the ops between from_ops and to_ops contain
- control flow loops.
- """
- # Mark reachable ops from from_ops.
- reached_ops = set()
- _MarkReachedOps(from_ops, reached_ops, func_graphs)
- # X in reached_ops iff X is reachable from from_ops by a path of zero or more
- # backpropagatable tensors.
-
- reachable_to_ops = set(op for op in to_ops if op in reached_ops)
-
- # Mark between ops.
- between_ops = set()
- between_op_list = []
- queue = collections.deque()
- queue.extend(to_ops)
- while queue:
- op = queue.popleft()
- # We are interested in this op.
- if op in reached_ops:
- between_ops.add(op)
- between_op_list.append(op)
- # Clear the boolean so we won't add the inputs again.
- reached_ops.remove(op)
- for inp in _NonEagerInputs(op, xs):
- queue.append(inp.op)
- # X in between_ops iff X is on a path of zero or more backpropagatable tensors
- # between from_ops and to_ops
-
- # 'loop_state' is None if there are no while loops.
- loop_state = control_flow_ops.MaybeCreateControlFlowState(
- between_op_list, between_ops, colocate_gradients_with_ops)
-
- # Initialize pending count for between ops.
- pending_count = collections.defaultdict(int)
- for op in between_op_list:
- for x in _NonEagerInputs(op, xs):
- if x.op in between_ops:
- pending_count[x.op] += 1
-
- return reachable_to_ops, pending_count, loop_state
-
-
-def _AsList(x):
- return x if isinstance(x, (list, tuple)) else [x]
-
-
-def _DefaultGradYs(grad_ys,
- ys,
- colocate_gradients_with_ops,
- gradient_uid="__unsupported__"):
- """Fill in default values for grad_ys.
-
- Args:
- grad_ys: List of gradients, can contain None.
- ys: List of tensors.
- colocate_gradients_with_ops: If True, try colocating gradients with
- the corresponding op.
- gradient_uid: A unique identifier within the graph indicating
- which invocation of gradients is being executed. Used to cluster
- ops for compilation.
-
- Returns:
- A list of gradients to use, without None.
-
- Raises:
- ValueError: If sizes of gradients and inputs don't match
- TypeError: If type of any gradient is not valid for its input.
- """
- if len(grad_ys) != len(ys):
- raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
- grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
- new_grad_ys = []
- for i in xrange(len(grad_ys)):
- grad_y = grad_ys[i]
- y = ys[i]
- with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
- if grad_y is None:
- if y.dtype.is_complex:
- raise TypeError(
- "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
- y.dtype)
- new_grad_ys.append(
- array_ops.fill(
- array_ops.shape(y),
- constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
- continue
- if y.dtype.is_floating or y.dtype.is_integer:
- if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
- raise TypeError(
- "Gradient type %s generated for real or "
- "integer-valued tensor %s with type %s must be "
- "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
- dtypes.as_dtype(y.dtype).name))
- elif y.dtype.is_complex:
- if not grad_y.dtype.is_complex:
- raise TypeError(
- "Gradient type %s generated for complex-valued "
- "tensor %s with type %s must be real" % (dtypes.as_dtype(
- grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
- elif y.dtype == dtypes.variant:
- if grad_y.dtype != dtypes.variant:
- raise TypeError(
- "Gradient type %s generated for variant "
- "tensor %s with type %s must be variant" % (dtypes.as_dtype(
- grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
- elif y.dtype == dtypes.resource:
- # We assume y is the handle of a ResourceVariable. The gradient of a
- # ResourceVariable should be a numeric value, not another resource.
- if grad_y.dtype == dtypes.resource:
- raise TypeError("Input gradient %s for resource tensor %s should not "
- "be a resource" % (grad_y, y))
- else:
- raise TypeError(
- "Tensor %s with type %s must be numeric "
- "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
- # Create a grad_y tensor in the name scope of the gradient.
- # Required for TensorArrays to identify which gradient call a
- # grad_y value is coming from.
- if isinstance(grad_y, ops.IndexedSlices):
- new_grad_ys.append(
- ops.IndexedSlices(
- indices=(array_ops.identity(
- grad_y.indices, name="grad_ys_%d_indices" % i)
- if isinstance(grad_y.indices, ops.Tensor) else
- grad_y.indices),
- values=(array_ops.identity(
- grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
- grad_y.values, ops.Tensor) else grad_y.values),
- dense_shape=(array_ops.identity(
- grad_y.dense_shape, name="grad_ys_%d_shape" % i)
- if isinstance(grad_y.dense_shape, ops.Tensor) else
- grad_y.dense_shape)))
- else:
- new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
-
- return new_grad_ys
-
-
-def IsTrainable(tensor_or_dtype):
- if isinstance(tensor_or_dtype, ops.Tensor):
- dtype = tensor_or_dtype.dtype
- else:
- dtype = tensor_or_dtype
- dtype = dtypes.as_dtype(dtype)
- return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
- dtypes.complex64, dtypes.complex128,
- dtypes.resource, dtypes.variant)
-
-
-def _IsBackpropagatable(tensor):
- if IsTrainable(tensor):
- return True
- dtype = dtypes.as_dtype(tensor.dtype)
- return dtype.base_dtype == dtypes.bfloat16
-
-
-def _VerifyGeneratedGradients(grads, op):
- """Verify that gradients are valid in number and type.
-
- Args:
- grads: List of generated gradients.
- op: Operation for which the gradients where generated.
-
- Raises:
- ValueError: if sizes of gradients and inputs don't match.
- TypeError: if type of any gradient is not valid for its input.
- """
- # While ops have inputs added to them during the gradient computation, so we
- # skip the below check. See while_v2 for details.
- if op.type == "While": return
-
- if len(grads) != len(op.inputs):
- raise ValueError("Num gradients %d generated for op %s do not match num "
- "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
-
-
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
- """The set of ops that terminate the gradient computation.
-
- This computes the frontier of the forward graph *before* which backprop
- should stop. Operations in the returned set will not be differentiated.
- This set is defined as the subset of `from_ops` containing ops that have
- no predecessor in `from_ops`. `pending_count` is the result of
- `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
- iff pending_count[op] > 0.
-
- In addition, none of `stop_gradient_ops` will be differentiated.
-
- Args:
- from_ops: list of Operations.
- stop_gradient_ops: list of Operations never to backprop through.
- pending_count: mapping from operation to number of backprop inputs.
- xs: list of Tensors.
-
- Returns:
- The set of operations.
- """
- stop_ops = set()
- for op in from_ops:
- is_stop_op = True
- for inp in _NonEagerInputs(op, xs):
- if pending_count[inp.op] > 0:
- is_stop_op = False
- break
- if is_stop_op:
- stop_ops.add(op)
- stop_ops.update(op for op in stop_gradient_ops)
- return stop_ops
-
-
-@contextlib.contextmanager
-def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
- """Context to colocate with `op` if `colocate_gradients_with_ops`."""
- if colocate_gradients_with_ops:
- with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
- yield
- else:
- yield
-
-
-def _IsPartitionedCall(op):
- return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
-
-
-def _SymGrad(op, out_grads):
- """Backprop through a function call node op given its outputs' gradients."""
- f_in = [x for x in op.inputs] + out_grads
- f_types = [x.dtype for x in op.inputs]
- f = attr_value_pb2.NameAttrList()
- if _IsPartitionedCall(op):
- f.name = op.get_attr("f").name
- else:
- f.name = op.type
- for k in op.node_def.attr:
- f.attr[k].CopyFrom(op.node_def.attr[k])
- # TODO(apassos) use a better dtype here
- in_grads = functional_ops.symbolic_gradient(
- input=f_in,
- Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types],
- f=f)
- return in_grads
-
-
-def _MaybeCompile(scope, op, func, grad_fn):
- """Compile the calculation in grad_fn if op was marked as compiled."""
- scope = scope.rstrip("/").replace("/", "_")
- if func is not None:
- xla_compile = func.definition.attr["_XlaCompile"].b
- xla_separate_compiled_gradients = func.definition.attr[
- "_XlaSeparateCompiledGradients"].b
- xla_scope = func.definition.attr["_XlaScope"].s.decode()
- else:
- try:
- xla_compile = op.get_attr("_XlaCompile")
- xla_separate_compiled_gradients = op.get_attr(
- "_XlaSeparateCompiledGradients")
- xla_scope = op.get_attr("_XlaScope").decode()
- except ValueError:
- return grad_fn() # Exit early
-
- if not xla_compile:
- return grad_fn() # Exit early
-
- # If the gradients are supposed to be compiled separately, we give them a
- # _XlaScope name that is based on the name_scope of the gradients. Otherwise
- # they just inherit the existing _XlaScope name, which lets them be merged
- # together with the non-gradient computation.
- if xla_separate_compiled_gradients:
- xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
- else:
- xla_grad_scope = xla_scope
-
- attrs = {
- "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
- "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
- }
- with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access
- return grad_fn()
-
-
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
- """Raises an error if we backprop through a loop var."""
- # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
- # message.
- target_op = None
- queue = collections.deque([op])
- visited = set()
- while queue:
- curr_op = queue.popleft()
- if curr_op in visited: continue
- visited.add(curr_op)
- if curr_op in from_ops:
- target_op = curr_op
- break
- queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
- assert target_op
- raise ValueError(
- "Cannot compute gradient inside while loop with respect to op '%s'. "
- "We do not support taking the gradient wrt or through the initial value "
- "of a loop variable. Gradients can be computed through loop invariants "
- "or wrt the input parameters to the loop body."
- % target_op.name)
-
-
-def _IsFunction(graph):
- return (isinstance(graph, FuncGraph) or
- isinstance(graph, framework_function._FuncGraph)) # pylint: disable=protected-access
-
-
-def _Captures(func_graph):
- if isinstance(func_graph, FuncGraph):
- return func_graph.captures
- else:
- assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access
- return func_graph._captured # pylint: disable=protected-access
-
-
-def _MaybeCaptured(t):
- """If t is a captured value placeholder, returns the original captured value.
-
- Args:
- t: Tensor
-
- Returns:
- A tensor, potentially from a different Graph/FuncGraph.
- """
- # pylint: disable=protected-access
- if (not isinstance(t, ops.EagerTensor) and
- _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
- for input_t, placeholder_t in _Captures(t.op.graph).items():
- if t == placeholder_t:
- return _MaybeCaptured(input_t)
- # pylint: enable=protected-access
- return t
-
-
-# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
-# _GradientsHelper a class with xs as a member variable.
-def _NonEagerInputs(op, xs):
- """Returns the inputs of op, crossing closure boundaries where necessary.
-
- Does not return any captured EagerTensors, i.e., the number of tensors
- returned may be less than than the actual number of inputs.
-
- Args:
- op: Operation
- xs: list of Tensors we are differentiating w.r.t.
-
- Returns:
- A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
- is in a FuncGraph and has captured inputs.
- """
- if _IsFunction(op.graph): # pylint: disable=protected-access
- inputs = []
- for t in op.inputs:
- # If we're differentiating w.r.t. `t`, do not attempt to traverse through
- # it to a captured value. The algorithm needs to "see" `t` in this case,
- # even if it's a function input for a captured value, whereas usually we'd
- # like to traverse through these closures as if the captured value was the
- # direct input to op.
- if t not in xs:
- t = _MaybeCaptured(t)
- # Skip captured eager inputs.
- if isinstance(t, ops.EagerTensor): continue
- inputs.append(t)
- return inputs
- else:
- return op.inputs
-
-
-def _Consumers(t, func_graphs):
- """Returns the consumers of t, crossing closure boundaries where necessary.
-
- Args:
- t: Tensor
- func_graphs: a list of FuncGraphs that may have captured t.
-
- Returns:
- A list of tensors. The tensors will be from the current graph and/or
- func_graphs.
- """
- consumers = t.consumers()
- for func in func_graphs:
- for input_t, placeholder in _Captures(func).items():
- if input_t == t:
- consumers.extend(_Consumers(placeholder, func_graphs))
- return consumers
-
-
@tf_export(v1=["gradients"])
def gradients(ys,
xs,
@@ -658,10 +150,13 @@
# Creating the gradient graph for control flow mutates Operations.
# _mutation_lock ensures a Session.run call cannot occur between creating and
# mutating new ops.
- with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
- return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
- gate_gradients, aggregation_method, stop_gradients,
- unconnected_gradients)
+ # pylint: disable=protected-access
+ with ops.get_default_graph()._mutation_lock():
+ return gradients_util._GradientsHelper(
+ ys, xs, grad_ys, name, colocate_gradients_with_ops,
+ gate_gradients, aggregation_method, stop_gradients,
+ unconnected_gradients)
+ # pylint: enable=protected-access
@tf_export("gradients", v1=[])
@@ -771,540 +266,13 @@
# Creating the gradient graph for control flow mutates Operations.
# _mutation_lock ensures a Session.run call cannot occur between creating and
# mutating new ops.
- with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
- return _GradientsHelper(ys, xs, grad_ys, name, True, gate_gradients,
- aggregation_method, stop_gradients,
- unconnected_gradients)
-
-
-def _GradientsHelper(ys,
- xs,
- grad_ys=None,
- name="gradients",
- colocate_gradients_with_ops=False,
- gate_gradients=False,
- aggregation_method=None,
- stop_gradients=None,
- unconnected_gradients=UnconnectedGradients.NONE,
- src_graph=None):
- """Implementation of gradients()."""
- if context.executing_eagerly():
- raise RuntimeError("tf.gradients is not supported when eager execution "
- "is enabled. Use tf.GradientTape instead.")
- if src_graph is None:
- src_graph = ops.get_default_graph()
- try:
- unconnected_gradients = UnconnectedGradients(unconnected_gradients)
- except ValueError:
- raise ValueError(
- "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
-
- # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
- # ancestor graphs. This is necessary for correctly handling captured values.
- func_graphs = []
- curr_graph = src_graph
- while _IsFunction(curr_graph):
- func_graphs.append(curr_graph)
- if isinstance(curr_graph, FuncGraph):
- curr_graph = curr_graph.outer_graph
- else:
- assert isinstance(curr_graph, framework_function._FuncGraph) # pylint: disable=protected-access
- curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
-
- ys = _AsList(ys)
- xs = _AsList(xs)
- stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
- if grad_ys is None:
- grad_ys = [None] * len(ys)
- else:
- grad_ys = _AsList(grad_ys)
-
- with ops.name_scope(
- name, "gradients",
- list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
- # Get a uid for this call to gradients that can be used to help
- # cluster ops for compilation.
- gradient_uid = ops.get_default_graph().unique_name("uid")
- ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
- xs = [
- x.handle if resource_variable_ops.is_resource_variable(x) else x
- for x in xs
- ]
- xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
- xs, name="x", as_ref=True)
- grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
- gradient_uid)
-
- # The approach we take here is as follows: Create a list of all ops in the
- # subgraph between the ys and xs. Visit these ops in reverse order of ids
- # to ensure that when we visit an op the gradients w.r.t its outputs have
- # been collected. Then aggregate these gradients if needed, call the op's
- # gradient function, and add the generated gradients to the gradients for
- # its input.
-
- # Initialize the pending count for ops in the connected subgraph from ys
- # to the xs.
- to_ops = [t.op for t in ys]
- from_ops = [t.op for t in xs]
- stop_gradient_ops = [t.op for t in stop_gradients]
- reachable_to_ops, pending_count, loop_state = _PendingCount(
- to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
-
- # Iterate over the collected ops.
- #
- # grads: op => list of gradients received on each output endpoint of the
- # op. The gradients for each endpoint are initially collected as a list.
- # When it is time to call the op's gradient function, for each endpoint we
- # aggregate the list of received gradients into a Add() Operation if there
- # is more than one.
- grads = {}
-
- # Add the initial gradients for the ys.
- for y, grad_y in zip(ys, grad_ys):
- _SetGrad(grads, y, grad_y)
-
- # Initialize queue with to_ops.
- queue = collections.deque()
- # Add the ops in 'to_ops' into the queue.
- to_ops_set = set()
- for op in to_ops:
- # 'ready' handles the case where one output gradient relies on
- # another output's gradient.
- ready = (pending_count[op] == 0)
- if ready and op not in to_ops_set and op in reachable_to_ops:
- to_ops_set.add(op)
- queue.append(op)
-
- if loop_state:
- loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
- for y in loop_exits:
- if IsTrainable(y):
- _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
- queue.append(y.op)
-
- stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
- while queue:
- # generate gradient subgraph for op.
- op = queue.popleft()
- with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
- if loop_state:
- loop_state.EnterGradWhileContext(op, before=True)
- out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
- aggregation_method)
- if loop_state:
- loop_state.ExitGradWhileContext(op, before=True)
-
- grad_fn = None
- func_call = None
- is_partitioned_call = _IsPartitionedCall(op)
- # pylint: disable=protected-access
- is_func_call = (
- src_graph._is_function(op.type) or is_partitioned_call)
- # pylint: enable=protected-access
- has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
- if has_out_grads and (op not in stop_ops):
- try:
- grad_fn = ops.get_gradient_function(op)
- except LookupError:
- if is_func_call:
- if is_partitioned_call:
- func_call = src_graph._get_function( # pylint: disable=protected-access
- compat.as_bytes(op.get_attr("f").name))
- else:
- func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
- # Note that __defun is not set if the graph is
- # imported. If it's set, we prefer to access the original
- # defun.
- func_call = getattr(op, "__defun", func_call)
- grad_fn = func_call.python_grad_func
- else:
- raise LookupError(
- "No gradient defined for operation '%s' (op type: %s)" %
- (op.name, op.type))
- if loop_state:
- loop_state.EnterGradWhileContext(op, before=False)
-
- # NOTE(skyewm): We don't support computing gradients wrt a loop variable
- # unless it's within the context of a single iteration (i.e. the
- # gradient is wrt to the loop parameter in the body function, not wrt or
- # through the initial value). This means if we're in a while loop
- # context, we should never see a switch node from this context.
- # pylint: disable=protected-access
- if (control_flow_util.IsSwitch(op) and
- op._control_flow_context is not None and
- op._control_flow_context.IsWhileContext() and
- op._control_flow_context ==
- ops.get_default_graph()._get_control_flow_context()):
- _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
- # pylint: enable=protected-access
-
- if (grad_fn or is_func_call) and has_out_grads:
- # NOTE: If _AggregatedGrads didn't compute a value for the i'th
- # output, it means that the cost does not depend on output[i],
- # therefore dC/doutput[i] is 0.
- for i, out_grad in enumerate(out_grads):
- if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
- (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
- # Only trainable outputs or outputs for a function call that
- # will use SymbolicGradient get a zero gradient. Gradient
- # functions should ignore the gradient for other outputs.
- # TODO(apassos) gradients of resource handles might be an
- # issue here because of zeros.
- if loop_state:
- out_grads[i] = loop_state.ZerosLike(op, i)
- else:
- out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
- with ops.name_scope(op.name + "_grad"):
- # pylint: disable=protected-access
- with src_graph._original_op(op):
- # pylint: enable=protected-access
- if grad_fn:
- # If grad_fn was found, do not use SymbolicGradient even for
- # functions.
- in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: grad_fn(op, *out_grads))
- else:
- # For function call ops, we add a 'SymbolicGradient'
- # node to the graph to compute gradients.
- in_grads = _MaybeCompile(grad_scope, op, func_call,
- lambda: _SymGrad(op, out_grads))
- in_grads = _AsList(in_grads)
- _VerifyGeneratedGradients(in_grads, op)
- if gate_gradients and len([x for x in in_grads
- if x is not None]) > 1:
- with ops.device(None):
- with ops._colocate_with_for_gradient( # pylint: disable=protected-access
- None,
- gradient_uid,
- ignore_existing=True):
- in_grads = control_flow_ops.tuple(in_grads)
- _LogOpGradients(op, out_grads, in_grads)
- else:
- # If no grad_fn is defined or none of out_grads is available,
- # just propagate a list of None backwards.
- in_grads = [None] * len(_NonEagerInputs(op, xs))
- for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
- in_grads)):
- if in_grad is not None:
- if (isinstance(in_grad, ops.Tensor) and
- t_in.dtype != dtypes.resource):
- try:
- in_grad.set_shape(t_in.get_shape())
- except ValueError:
- raise ValueError(
- "Incompatible shapes between op input and calculated "
- "input gradient. Forward operation: %s. Input index: %d. "
- "Original input shape: %s. "
- "Calculated input gradient shape: %s" %
- (op.name, i, t_in.shape, in_grad.shape))
- _SetGrad(grads, t_in, in_grad)
- if loop_state:
- loop_state.ExitGradWhileContext(op, before=False)
-
- # Update pending count for the inputs of op and enqueue ready ops.
- _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
- xs)
-
- if loop_state:
- loop_state.PostProcessing()
- return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
-
-
-def _HasAnyNotNoneGrads(grads, op):
- """Return true iff op has real gradient."""
- out_grads = _GetGrads(grads, op)
- for out_grad in out_grads:
- if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
- return True
- if out_grad and isinstance(out_grad, collections.Sequence):
- if any(g is not None for g in out_grad):
- return True
- return False
-
-
-def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
- xs):
- """Update pending count for the inputs of op and enqueue ready ops."""
- for x in _NonEagerInputs(op, xs):
- pending_count[x.op] -= 1
- ready = (pending_count[x.op] == 0)
- if loop_state and not ready:
- ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
- if ready:
- if control_flow_util.IsLoopExit(x.op):
- # if x is an exit without real gradient, defer processing them.
- grad_state = loop_state.GetGradState(x.op, before=False)
- grad_state.deferred_exits.append(x)
- grad_state.pending_exits_count -= 1
- if grad_state.pending_exits_count == 0:
- # We now have all the exits so process them.
- has_not_none_grad = False
- for y in grad_state.deferred_exits:
- if _HasAnyNotNoneGrads(grads, y.op):
- has_not_none_grad = True
- queue.append(y.op)
- else:
- grad_state.unused_exits.append(y)
- if has_not_none_grad:
- # For an unused exit, if it has trainable outputs, backprop
- # a zero gradient. Otherwise, just ignore it.
- for y in grad_state.unused_exits:
- if IsTrainable(y):
- _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
- queue.append(y.op)
- else:
- # All exits are "unused" so use None as gradient.
- for y in grad_state.unused_exits:
- queue.append(y.op)
- else:
- queue.append(x.op)
-
-
-def _SetGrad(grads, t, grad):
- """Sets gradient "grad" in "grads" for tensor "t"."""
- op = t.op
- op_grads = grads.get(op)
- if not op_grads:
- op_grads = [[] for _ in xrange(len(op.outputs))]
- grads[op] = op_grads
- t_grads = op_grads[t.value_index]
- if isinstance(t_grads, list):
- t_grads.append(grad)
- else:
- assert control_flow_util.IsLoopSwitch(op)
- op_grads[t.value_index] = grad
-
-
-def _GetGrad(grads, t, unconnected_gradients):
- """Gets gradient for tensor "t"."""
- op = t.op
- op_grads = grads.get(op)
- if not op_grads:
- if unconnected_gradients == UnconnectedGradients.ZERO:
- t_dtype = t.dtype if t.dtype != dtypes.resource else dtypes.float32
- return array_ops.zeros_like(t, dtype=t_dtype)
- elif unconnected_gradients == UnconnectedGradients.NONE:
- return None
- else:
- raise ValueError(
- "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
-
- t_grad = op_grads[t.value_index]
- assert not isinstance(
- t_grad, list), ("gradients list should have been aggregated by now.")
- return t_grad
-
-
-def _GetGrads(grads, op):
- """Gets all gradients for op."""
- if op in grads:
- return grads[op]
- else:
- return [[] for _ in xrange(len(op.outputs))]
-
-
-def _HandleNestedIndexedSlices(grad):
- assert isinstance(grad, ops.IndexedSlices)
- if isinstance(grad.values, ops.Tensor):
- return grad
- else:
- assert isinstance(grad.values, ops.IndexedSlices)
- g = _HandleNestedIndexedSlices(grad.values)
- return ops.IndexedSlices(g.values, array_ops.gather(
- grad.indices, g.indices), g.dense_shape)
-
-
-def _AccumulatorShape(inputs):
- shape = tensor_shape.unknown_shape()
- for i in inputs:
- if isinstance(i, ops.Tensor):
- shape = shape.merge_with(i.get_shape())
- return shape
-
-
-def _LogOpGradients(op, out_grads, in_grads):
- """Log the in and out grads of an op."""
- logging.vlog(1, "Gradient for '" + op.name + "'")
-
- def _FilterGrad(x):
- if x is None:
- return False
- if isinstance(x, (list, tuple)):
- return bool(x)
- else:
- return True
-
- logging.vlog(1, " in --> %s",
- ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
- logging.vlog(1, " out --> %s",
- ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
-
-
-def _MultiDeviceAddN(tensor_list, gradient_uid):
- """Adds tensors from potentially multiple devices."""
- # Basic function structure comes from control_flow_ops.group().
- # Sort tensors according to their devices.
- tensors_on_device = collections.defaultdict(lambda: [])
- for tensor in tensor_list:
- tensors_on_device[tensor.device].append(tensor)
-
- # For each device, add the tensors on that device first.
- # Then gather the partial sums from multiple devices.
- # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
- # E.g., aggregate per GPU, then per task, and so on.
- summands = []
-
- def DeviceKey(dev):
- return "" if dev is None else dev
-
- for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
- tensors = tensors_on_device[dev]
- with ops._colocate_with_for_gradient( # pylint: disable=protected-access
- tensors[0].op,
- gradient_uid,
- ignore_existing=True):
- summands.append(math_ops.add_n(tensors))
-
- return math_ops.add_n(summands)
-
-
-@tf_export("AggregationMethod")
-class AggregationMethod(object):
- """A class listing aggregation methods used to combine gradients.
-
- Computing partial derivatives can require aggregating gradient
- contributions. This class lists the various methods that can
- be used to combine gradients in the graph:
-
- * `ADD_N`: All of the gradient terms are summed as part of one
- operation using the "AddN" op. It has the property that all
- gradients must be ready before any aggregation is performed.
- * `DEFAULT`: The system-chosen default aggregation method.
- """
- ADD_N = 0
- DEFAULT = ADD_N
- # The following are experimental and may not be supported in future releases.
- EXPERIMENTAL_TREE = 1
- EXPERIMENTAL_ACCUMULATE_N = 2
-
-
-def _AggregatedGrads(grads,
- op,
- gradient_uid,
- loop_state,
- aggregation_method=None):
- """Get the aggregated gradients for op.
-
- Args:
- grads: The map of memoized gradients.
- op: The op to get gradients for.
- gradient_uid: A unique identifier within the graph indicating
- which invocation of gradients is being executed. Used to cluster
- ops for compilation.
- loop_state: An object for maintaining the state of the while loops in the
- graph. It is of type ControlFlowState. None if the graph
- contains no while loops.
- aggregation_method: Specifies the method used to combine gradient terms.
- Accepted values are constants defined in the class `AggregationMethod`.
-
- Returns:
- A list of gradients, one per each output of `op`. If the gradients
- for a particular output is a list, this function aggregates it
- before returning.
-
- Raises:
- TypeError: if the incoming grads are not Tensors or IndexedSlices.
- ValueError: if the arguments are invalid.
-
- """
- if aggregation_method is None:
- aggregation_method = AggregationMethod.DEFAULT
- if aggregation_method not in [
- AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
- AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
- ]:
- raise ValueError(
- "Invalid aggregation_method specified %s." % aggregation_method)
- out_grads = _GetGrads(grads, op)
- for i, out_grad in enumerate(out_grads):
- if loop_state:
- if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
- assert control_flow_util.IsLoopSwitch(op)
- continue
- # Grads have to be Tensors or IndexedSlices
- if (isinstance(out_grad, collections.Sequence) and not all(
- isinstance(g, (ops.Tensor, ops.IndexedSlices))
- for g in out_grad
- if g is not None
- )):
- raise TypeError("gradients have to be either all Tensors "
- "or all IndexedSlices")
- # Aggregate multiple gradients, and convert [] to None.
- if out_grad:
- if len(out_grad) < 2:
- used = "nop"
- out_grads[i] = out_grad[0]
- elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
- tensor_shape = _AccumulatorShape(out_grad)
- if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
- and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
- # The benefit of using AccumulateN is that its inputs can be combined
- # in any order and this can allow the expression to be evaluated with
- # a smaller memory footprint. When used with gpu_allocator_retry,
- # it is possible to compute a sum of terms which are much larger than
- # total GPU memory.
- # AccumulateN can currently only be used if we know the shape for
- # an accumulator variable. If this is not known, or if we only have
- # 2 grads then we fall through to the "tree" case below.
- used = "accumulate_n"
- out_grads[i] = math_ops.accumulate_n(out_grad)
- elif aggregation_method in [
- AggregationMethod.EXPERIMENTAL_TREE,
- AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
- ]:
- # Aggregate all gradients by doing pairwise sums: this may
- # reduce performance, but it can improve memory because the
- # gradients can be released earlier.
- #
- # TODO(vrv): Consider replacing this with a version of
- # tf.AddN() that eagerly frees its inputs as soon as they are
- # ready, so the order of this tree does not become a problem.
- used = "tree"
- with ops.name_scope(op.name + "_gradient_sum"):
- running_sum = out_grad[0]
- for grad in out_grad[1:]:
- running_sum = math_ops.add_n([running_sum, grad])
- out_grads[i] = running_sum
- else:
- used = "add_n"
- out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
- logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
- tensor_shape, used)
- else:
- out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
- else: # not out_grad
- # out_grads[i] is [], thus its aggregation is simply None.
- out_grads[i] = None
- return out_grads
-
-
-def _AggregateIndexedSlicesGradients(grads):
- """Aggregates gradients of type `IndexedSlices` by concatenation."""
- if len(grads) < 1:
- return None
- elif len(grads) == 1:
- return grads[0]
- else:
- grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access
- [g for g in grads if g is not None])
- grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access
- # Form IndexedSlices out of the concatenated values and indices.
- concat_grad = ops.IndexedSlices(
- array_ops.concat([x.values for x in grads], axis=0),
- array_ops.concat([x.indices for x in grads], axis=0),
- grads[0].dense_shape)
-
- return concat_grad
+ # pylint: disable=protected-access
+ with ops.get_default_graph()._mutation_lock():
+ return gradients_util._GradientsHelper(
+ ys, xs, grad_ys, name, True, gate_gradients,
+ aggregation_method, stop_gradients,
+ unconnected_gradients)
+ # pylint: enable=protected-access
# TODO(vrv): Make this available when we want to make it public.
@@ -1393,7 +361,7 @@
LookupError: if one of the operations between `xs` and `ys` does not
have a registered gradient function.
"""
- xs = _AsList(xs)
+ xs = gradients_util._AsList(xs) # pylint: disable=protected-access
kwargs = {
"colocate_gradients_with_ops": colocate_gradients_with_ops,
"gate_gradients": gate_gradients,
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index c53afef..9caffa3 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -45,6 +45,7 @@
from tensorflow.python.ops import functional_ops # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
@@ -1040,12 +1041,12 @@
self.evaluate(ops.convert_to_tensor(right)))
def testNoGradients(self):
- self.assertIsNone(gradients_impl._AggregateIndexedSlicesGradients([]))
+ self.assertIsNone(gradients_util._AggregateIndexedSlicesGradients([]))
def testOneGradient(self):
t = math_ops._as_indexed_slices(constant_op.constant(
[[1., 2.], [0, 0], [3., 4.]]))
- result = gradients_impl._AggregateIndexedSlicesGradients([t])
+ result = gradients_util._AggregateIndexedSlicesGradients([t])
self._assert_indexed_slices_equal(t, result)
def testMultipleGradients(self):
@@ -1055,7 +1056,7 @@
[[0., 0.], [5, 6], [7., 8.]]))
total = constant_op.constant(
[[1., 2.], [5, 6], [10., 12.]])
- result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ result = gradients_util._AggregateIndexedSlicesGradients([t0, t1])
self._assert_indexed_slices_equal(total, result)
def testMultipleGradientsWithNones(self):
@@ -1066,7 +1067,7 @@
t3 = None
total = constant_op.constant(
[[1., 2.], [5, 6], [10., 12.]])
- result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1, t3])
+ result = gradients_util._AggregateIndexedSlicesGradients([t0, t1, t3])
self._assert_indexed_slices_equal(total, result)
def testMixedTensorAndIndexedSlices(self):
@@ -1076,7 +1077,7 @@
[[0., 0.], [5, 6], [7., 8.]])
total = constant_op.constant(
[[1., 2.], [5, 6], [10., 12.]])
- result = gradients_impl._AggregateIndexedSlicesGradients([t0, t1])
+ result = gradients_util._AggregateIndexedSlicesGradients([t0, t1])
self._assert_indexed_slices_equal(total, result)
diff --git a/tensorflow/python/ops/gradients_util.py b/tensorflow/python/ops/gradients_util.py
new file mode 100644
index 0000000..4f73423
--- /dev/null
+++ b/tensorflow/python/ops/gradients_util.py
@@ -0,0 +1,1062 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implements the graph generation for computation of gradients."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import contextlib
+import warnings
+
+import numpy as np
+import six
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function as framework_function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import control_flow_util
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops.unconnected_gradients import UnconnectedGradients
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
+from tensorflow.python.util.tf_export import tf_export
+
+
+# Warn the user if we convert a sparse representation to dense with at
+# least this number of elements.
+_LARGE_SPARSE_NUM_ELEMENTS = 100000000
+
+
+def _IndexedSlicesToTensor(value, dtype=None, name=None, as_ref=False):
+ """Converts an IndexedSlices object `value` to a Tensor.
+
+ NOTE(mrry): This function is potentially expensive.
+
+ Args:
+ value: An ops.IndexedSlices object.
+ dtype: The dtype of the Tensor to be returned.
+ name: Optional name to use for the returned Tensor.
+ as_ref: True if a ref is requested.
+
+ Returns:
+ A dense Tensor representing the values in the given IndexedSlices.
+
+ Raises:
+ ValueError: If the IndexedSlices does not have the same dtype.
+ """
+ _ = as_ref
+ if dtype and not dtype.is_compatible_with(value.dtype):
+ raise ValueError(
+ "Tensor conversion requested dtype %s for IndexedSlices with dtype %s" %
+ (dtype.name, value.dtype.name))
+ if value.dense_shape is None:
+ raise ValueError(
+ "Tensor conversion requested for IndexedSlices without dense_shape: %s"
+ % str(value))
+ # TODO(mrry): Consider adding static shape information to
+ # IndexedSlices, to avoid using numpy here.
+ if not context.executing_eagerly():
+ dense_shape_value = tensor_util.constant_value(value.dense_shape)
+ if dense_shape_value is not None:
+ num_elements = np.prod(dense_shape_value)
+ if num_elements >= _LARGE_SPARSE_NUM_ELEMENTS:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor with %d "
+ "elements. This may consume a large amount of memory." %
+ num_elements)
+ else:
+ warnings.warn(
+ "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
+ "This may consume a large amount of memory.")
+ return math_ops.unsorted_segment_sum(
+ value.values, value.indices, value.dense_shape[0], name=name)
+
+
+ops.register_tensor_conversion_function(ops.IndexedSlices,
+ _IndexedSlicesToTensor)
+
+
+def _MarkReachedOps(from_ops, reached_ops, func_graphs):
+ """Mark all ops reached from "from_ops".
+
+ Args:
+ from_ops: list of Operations.
+ reached_ops: set of Operations.
+ func_graphs: list of FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops.
+ """
+ queue = collections.deque()
+ queue.extend(from_ops)
+ while queue:
+ op = queue.popleft()
+ if op not in reached_ops:
+ reached_ops.add(op)
+ for output in op.outputs:
+ if _IsBackpropagatable(output):
+ queue.extend(_Consumers(output, func_graphs))
+
+
+def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
+ xs):
+ """Initialize the pending count for ops between two lists of Operations.
+
+ 'pending_count[op]' indicates the number of backprop inputs
+ to this operation.
+
+ Args:
+ to_ops: list of Operations.
+ from_ops: list of Operations.
+ colocate_gradients_with_ops: Python bool. See docstring of gradients().
+ func_graphs: list of FuncGraphs. This method will traverse through
+ these functions if they capture from_ops or any reachable ops. This is
+ useful if to_ops occur in a function and from_ops are in an outer function
+ or graph.
+ xs: list of Tensors.
+
+ Returns:
+ A tuple containing: (1) the subset of to_ops reachable from from_ops by a
+ path of zero or more backpropagatable tensors, (2) a mapping from operation
+ to the number of backprop inputs to that op, and (3) a ControlFlowState
+ object which is not None if the ops between from_ops and to_ops contain
+ control flow loops.
+ """
+ # Mark reachable ops from from_ops.
+ reached_ops = set()
+ _MarkReachedOps(from_ops, reached_ops, func_graphs)
+ # X in reached_ops iff X is reachable from from_ops by a path of zero or more
+ # backpropagatable tensors.
+
+ reachable_to_ops = set(op for op in to_ops if op in reached_ops)
+
+ # Mark between ops.
+ between_ops = set()
+ between_op_list = []
+ queue = collections.deque()
+ queue.extend(to_ops)
+ while queue:
+ op = queue.popleft()
+ # We are interested in this op.
+ if op in reached_ops:
+ between_ops.add(op)
+ between_op_list.append(op)
+ # Clear the boolean so we won't add the inputs again.
+ reached_ops.remove(op)
+ for inp in _NonEagerInputs(op, xs):
+ queue.append(inp.op)
+ # X in between_ops iff X is on a path of zero or more backpropagatable tensors
+ # between from_ops and to_ops
+
+ # 'loop_state' is None if there are no while loops.
+ loop_state = control_flow_ops.MaybeCreateControlFlowState(
+ between_op_list, between_ops, colocate_gradients_with_ops)
+
+ # Initialize pending count for between ops.
+ pending_count = collections.defaultdict(int)
+ for op in between_op_list:
+ for x in _NonEagerInputs(op, xs):
+ if x.op in between_ops:
+ pending_count[x.op] += 1
+
+ return reachable_to_ops, pending_count, loop_state
+
+
+def _AsList(x):
+ return x if isinstance(x, (list, tuple)) else [x]
+
+
+def _DefaultGradYs(grad_ys,
+ ys,
+ colocate_gradients_with_ops,
+ gradient_uid="__unsupported__"):
+ """Fill in default values for grad_ys.
+
+ Args:
+ grad_ys: List of gradients, can contain None.
+ ys: List of tensors.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
+
+ Returns:
+ A list of gradients to use, without None.
+
+ Raises:
+ ValueError: If sizes of gradients and inputs don't match
+ TypeError: If type of any gradient is not valid for its input.
+ """
+ if len(grad_ys) != len(ys):
+ raise ValueError("Passed %d grad_ys for %d ys" % (len(grad_ys), len(ys)))
+ grad_ys = ops.convert_n_to_tensor_or_indexed_slices(grad_ys, name="grad_y")
+ new_grad_ys = []
+ for i in xrange(len(grad_ys)):
+ grad_y = grad_ys[i]
+ y = ys[i]
+ with _maybe_colocate_with(y.op, gradient_uid, colocate_gradients_with_ops):
+ if grad_y is None:
+ if y.dtype.is_complex:
+ raise TypeError(
+ "Gradients of complex tensors must set grad_ys (y.dtype = %r)" %
+ y.dtype)
+ new_grad_ys.append(
+ array_ops.fill(
+ array_ops.shape(y),
+ constant_op.constant(1, dtype=y.dtype, name="grad_ys_%d" % i)))
+ continue
+ if y.dtype.is_floating or y.dtype.is_integer:
+ if not grad_y.dtype.is_floating and not grad_y.dtype.is_integer:
+ raise TypeError(
+ "Gradient type %s generated for real or "
+ "integer-valued tensor %s with type %s must be "
+ "real or integer" % (dtypes.as_dtype(grad_y.dtype).name, y,
+ dtypes.as_dtype(y.dtype).name))
+ elif y.dtype.is_complex:
+ if not grad_y.dtype.is_complex:
+ raise TypeError(
+ "Gradient type %s generated for complex-valued "
+ "tensor %s with type %s must be real" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.variant:
+ if grad_y.dtype != dtypes.variant:
+ raise TypeError(
+ "Gradient type %s generated for variant "
+ "tensor %s with type %s must be variant" % (dtypes.as_dtype(
+ grad_y.dtype).name, y, dtypes.as_dtype(y.dtype).name))
+ elif y.dtype == dtypes.resource:
+ # We assume y is the handle of a ResourceVariable. The gradient of a
+ # ResourceVariable should be a numeric value, not another resource.
+ if grad_y.dtype == dtypes.resource:
+ raise TypeError("Input gradient %s for resource tensor %s should not "
+ "be a resource" % (grad_y, y))
+ else:
+ raise TypeError(
+ "Tensor %s with type %s must be numeric "
+ "to obtain a default gradient" % (y, dtypes.as_dtype(y.dtype).name))
+ # Create a grad_y tensor in the name scope of the gradient.
+ # Required for TensorArrays to identify which gradient call a
+ # grad_y value is coming from.
+ if isinstance(grad_y, ops.IndexedSlices):
+ new_grad_ys.append(
+ ops.IndexedSlices(
+ indices=(array_ops.identity(
+ grad_y.indices, name="grad_ys_%d_indices" % i)
+ if isinstance(grad_y.indices, ops.Tensor) else
+ grad_y.indices),
+ values=(array_ops.identity(
+ grad_y.values, name="grad_ys_%d_values" % i) if isinstance(
+ grad_y.values, ops.Tensor) else grad_y.values),
+ dense_shape=(array_ops.identity(
+ grad_y.dense_shape, name="grad_ys_%d_shape" % i)
+ if isinstance(grad_y.dense_shape, ops.Tensor) else
+ grad_y.dense_shape)))
+ else:
+ new_grad_ys.append(array_ops.identity(grad_y, name="grad_ys_%d" % i))
+
+ return new_grad_ys
+
+
+def IsTrainable(tensor_or_dtype):
+ if isinstance(tensor_or_dtype, ops.Tensor):
+ dtype = tensor_or_dtype.dtype
+ else:
+ dtype = tensor_or_dtype
+ dtype = dtypes.as_dtype(dtype)
+ return dtype.base_dtype in (dtypes.float16, dtypes.float32, dtypes.float64,
+ dtypes.complex64, dtypes.complex128,
+ dtypes.resource, dtypes.variant)
+
+
+def _IsBackpropagatable(tensor):
+ if IsTrainable(tensor):
+ return True
+ dtype = dtypes.as_dtype(tensor.dtype)
+ return dtype.base_dtype == dtypes.bfloat16
+
+
+def _VerifyGeneratedGradients(grads, op):
+ """Verify that gradients are valid in number and type.
+
+ Args:
+ grads: List of generated gradients.
+ op: Operation for which the gradients where generated.
+
+ Raises:
+ ValueError: if sizes of gradients and inputs don't match.
+ TypeError: if type of any gradient is not valid for its input.
+ """
+ # While ops have inputs added to them during the gradient computation, so we
+ # skip the below check. See while_v2 for details.
+ if op.type == "While": return
+
+ if len(grads) != len(op.inputs):
+ raise ValueError("Num gradients %d generated for op %s do not match num "
+ "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
+
+
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+ """The set of ops that terminate the gradient computation.
+
+ This computes the frontier of the forward graph *before* which backprop
+ should stop. Operations in the returned set will not be differentiated.
+ This set is defined as the subset of `from_ops` containing ops that have
+ no predecessor in `from_ops`. `pending_count` is the result of
+ `_PendingCount(xs, from_ops)`. An 'op' has predecessors in `from_ops`
+ iff pending_count[op] > 0.
+
+ In addition, none of `stop_gradient_ops` will be differentiated.
+
+ Args:
+ from_ops: list of Operations.
+ stop_gradient_ops: list of Operations never to backprop through.
+ pending_count: mapping from operation to number of backprop inputs.
+ xs: list of Tensors.
+
+ Returns:
+ The set of operations.
+ """
+ stop_ops = set()
+ for op in from_ops:
+ is_stop_op = True
+ for inp in _NonEagerInputs(op, xs):
+ if pending_count[inp.op] > 0:
+ is_stop_op = False
+ break
+ if is_stop_op:
+ stop_ops.add(op)
+ stop_ops.update(op for op in stop_gradient_ops)
+ return stop_ops
+
+
+@contextlib.contextmanager
+def _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops): # pylint: disable=invalid-name
+ """Context to colocate with `op` if `colocate_gradients_with_ops`."""
+ if colocate_gradients_with_ops:
+ with ops._colocate_with_for_gradient(op, gradient_uid): # pylint: disable=protected-access
+ yield
+ else:
+ yield
+
+
+def _IsPartitionedCall(op):
+ return op.type == "PartitionedCall" or op.type == "StatefulPartitionedCall"
+
+
+def _SymGrad(op, out_grads):
+ """Backprop through a function call node op given its outputs' gradients."""
+ f_in = [x for x in op.inputs] + out_grads
+ f_types = [x.dtype for x in op.inputs]
+ f = attr_value_pb2.NameAttrList()
+ if _IsPartitionedCall(op):
+ f.name = op.get_attr("f").name
+ else:
+ f.name = op.type
+ for k in op.node_def.attr:
+ f.attr[k].CopyFrom(op.node_def.attr[k])
+ # TODO(apassos) use a better dtype here
+ in_grads = functional_ops.symbolic_gradient(
+ input=f_in,
+ Tout=[x if x != dtypes.resource else dtypes.float32 for x in f_types],
+ f=f)
+ return in_grads
+
+
+def _MaybeCompile(scope, op, func, grad_fn):
+ """Compile the calculation in grad_fn if op was marked as compiled."""
+ scope = scope.rstrip("/").replace("/", "_")
+ if func is not None:
+ xla_compile = func.definition.attr["_XlaCompile"].b
+ xla_separate_compiled_gradients = func.definition.attr[
+ "_XlaSeparateCompiledGradients"].b
+ xla_scope = func.definition.attr["_XlaScope"].s.decode()
+ else:
+ try:
+ xla_compile = op.get_attr("_XlaCompile")
+ xla_separate_compiled_gradients = op.get_attr(
+ "_XlaSeparateCompiledGradients")
+ xla_scope = op.get_attr("_XlaScope").decode()
+ except ValueError:
+ return grad_fn() # Exit early
+
+ if not xla_compile:
+ return grad_fn() # Exit early
+
+ # If the gradients are supposed to be compiled separately, we give them a
+ # _XlaScope name that is based on the name_scope of the gradients. Otherwise
+ # they just inherit the existing _XlaScope name, which lets them be merged
+ # together with the non-gradient computation.
+ if xla_separate_compiled_gradients:
+ xla_grad_scope = "%s_grad_%s" % (xla_scope, scope)
+ else:
+ xla_grad_scope = xla_scope
+
+ attrs = {
+ "_XlaCompile": attr_value_pb2.AttrValue(b=xla_compile),
+ "_XlaScope": attr_value_pb2.AttrValue(s=xla_grad_scope.encode())
+ }
+ with ops.get_default_graph()._attr_scope(attrs): # pylint: disable=protected-access
+ return grad_fn()
+
+
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+ """Raises an error if we backprop through a loop var."""
+ # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
+ # message.
+ target_op = None
+ queue = collections.deque([op])
+ visited = set()
+ while queue:
+ curr_op = queue.popleft()
+ if curr_op in visited: continue
+ visited.add(curr_op)
+ if curr_op in from_ops:
+ target_op = curr_op
+ break
+ queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+ assert target_op
+ raise ValueError(
+ "Cannot compute gradient inside while loop with respect to op '%s'. "
+ "We do not support taking the gradient wrt or through the initial value "
+ "of a loop variable. Gradients can be computed through loop invariants "
+ "or wrt the input parameters to the loop body."
+ % target_op.name)
+
+
+def _IsFunction(graph):
+ return (isinstance(graph, FuncGraph) or
+ isinstance(graph, framework_function._FuncGraph)) # pylint: disable=protected-access
+
+
+def _Captures(func_graph):
+ if isinstance(func_graph, FuncGraph):
+ return func_graph.captures
+ else:
+ assert isinstance(func_graph, framework_function._FuncGraph) # pylint: disable=protected-access
+ return func_graph._captured # pylint: disable=protected-access
+
+
+def _MaybeCaptured(t):
+ """If t is a captured value placeholder, returns the original captured value.
+
+ Args:
+ t: Tensor
+
+ Returns:
+ A tensor, potentially from a different Graph/FuncGraph.
+ """
+ # pylint: disable=protected-access
+ if (not isinstance(t, ops.EagerTensor) and
+ _IsFunction(t.op.graph) and t.op.type == "Placeholder"):
+ for input_t, placeholder_t in _Captures(t.op.graph).items():
+ if t == placeholder_t:
+ return _MaybeCaptured(input_t)
+ # pylint: enable=protected-access
+ return t
+
+
+# TODO(skyewm): plumbing xs through everywhere is ugly, consider making
+# _GradientsHelper a class with xs as a member variable.
+def _NonEagerInputs(op, xs):
+ """Returns the inputs of op, crossing closure boundaries where necessary.
+
+ Does not return any captured EagerTensors, i.e., the number of tensors
+ returned may be less than than the actual number of inputs.
+
+ Args:
+ op: Operation
+ xs: list of Tensors we are differentiating w.r.t.
+
+ Returns:
+ A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
+ is in a FuncGraph and has captured inputs.
+ """
+ if _IsFunction(op.graph): # pylint: disable=protected-access
+ inputs = []
+ for t in op.inputs:
+ # If we're differentiating w.r.t. `t`, do not attempt to traverse through
+ # it to a captured value. The algorithm needs to "see" `t` in this case,
+ # even if it's a function input for a captured value, whereas usually we'd
+ # like to traverse through these closures as if the captured value was the
+ # direct input to op.
+ if t not in xs:
+ t = _MaybeCaptured(t)
+ # Skip captured eager inputs.
+ if isinstance(t, ops.EagerTensor): continue
+ inputs.append(t)
+ return inputs
+ else:
+ return op.inputs
+
+
+def _Consumers(t, func_graphs):
+ """Returns the consumers of t, crossing closure boundaries where necessary.
+
+ Args:
+ t: Tensor
+ func_graphs: a list of FuncGraphs that may have captured t.
+
+ Returns:
+ A list of tensors. The tensors will be from the current graph and/or
+ func_graphs.
+ """
+ consumers = t.consumers()
+ for func in func_graphs:
+ for input_t, placeholder in _Captures(func).items():
+ if input_t == t:
+ consumers.extend(_Consumers(placeholder, func_graphs))
+ return consumers
+
+
+def _GradientsHelper(ys,
+ xs,
+ grad_ys=None,
+ name="gradients",
+ colocate_gradients_with_ops=False,
+ gate_gradients=False,
+ aggregation_method=None,
+ stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE,
+ src_graph=None):
+ """Implementation of gradients()."""
+ if context.executing_eagerly():
+ raise RuntimeError("tf.gradients is not supported when eager execution "
+ "is enabled. Use tf.GradientTape instead.")
+ if src_graph is None:
+ src_graph = ops.get_default_graph()
+ try:
+ unconnected_gradients = UnconnectedGradients(unconnected_gradients)
+ except ValueError:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
+ # If src_graph is a _FuncGraph (i.e. a function body), gather it and all
+ # ancestor graphs. This is necessary for correctly handling captured values.
+ func_graphs = []
+ curr_graph = src_graph
+ while _IsFunction(curr_graph):
+ func_graphs.append(curr_graph)
+ if isinstance(curr_graph, FuncGraph):
+ curr_graph = curr_graph.outer_graph
+ else:
+ assert isinstance(curr_graph, framework_function._FuncGraph) # pylint: disable=protected-access
+ curr_graph = curr_graph._outer_graph # pylint: disable=protected-access
+
+ ys = _AsList(ys)
+ xs = _AsList(xs)
+ stop_gradients = [] if stop_gradients is None else _AsList(stop_gradients)
+ if grad_ys is None:
+ grad_ys = [None] * len(ys)
+ else:
+ grad_ys = _AsList(grad_ys)
+
+ with ops.name_scope(
+ name, "gradients",
+ list(ys) + list(xs) + list(stop_gradients) + list(grad_ys)) as grad_scope:
+ # Get a uid for this call to gradients that can be used to help
+ # cluster ops for compilation.
+ gradient_uid = ops.get_default_graph().unique_name("uid")
+ ys = ops.convert_n_to_tensor_or_indexed_slices(ys, name="y")
+ xs = [
+ x.handle if resource_variable_ops.is_resource_variable(x) else x
+ for x in xs
+ ]
+ xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
+ xs, name="x", as_ref=True)
+ grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
+ gradient_uid)
+
+ # The approach we take here is as follows: Create a list of all ops in the
+ # subgraph between the ys and xs. Visit these ops in reverse order of ids
+ # to ensure that when we visit an op the gradients w.r.t its outputs have
+ # been collected. Then aggregate these gradients if needed, call the op's
+ # gradient function, and add the generated gradients to the gradients for
+ # its input.
+
+ # Initialize the pending count for ops in the connected subgraph from ys
+ # to the xs.
+ to_ops = [t.op for t in ys]
+ from_ops = [t.op for t in xs]
+ stop_gradient_ops = [t.op for t in stop_gradients]
+ reachable_to_ops, pending_count, loop_state = _PendingCount(
+ to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+
+ # Iterate over the collected ops.
+ #
+ # grads: op => list of gradients received on each output endpoint of the
+ # op. The gradients for each endpoint are initially collected as a list.
+ # When it is time to call the op's gradient function, for each endpoint we
+ # aggregate the list of received gradients into a Add() Operation if there
+ # is more than one.
+ grads = {}
+
+ # Add the initial gradients for the ys.
+ for y, grad_y in zip(ys, grad_ys):
+ _SetGrad(grads, y, grad_y)
+
+ # Initialize queue with to_ops.
+ queue = collections.deque()
+ # Add the ops in 'to_ops' into the queue.
+ to_ops_set = set()
+ for op in to_ops:
+ # 'ready' handles the case where one output gradient relies on
+ # another output's gradient.
+ ready = (pending_count[op] == 0)
+ if ready and op not in to_ops_set and op in reachable_to_ops:
+ to_ops_set.add(op)
+ queue.append(op)
+
+ if loop_state:
+ loop_exits = loop_state.ProcessUnusedLoopExits(pending_count, to_ops_set)
+ for y in loop_exits:
+ if IsTrainable(y):
+ _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
+ queue.append(y.op)
+
+ stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+ while queue:
+ # generate gradient subgraph for op.
+ op = queue.popleft()
+ with _maybe_colocate_with(op, gradient_uid, colocate_gradients_with_ops):
+ if loop_state:
+ loop_state.EnterGradWhileContext(op, before=True)
+ out_grads = _AggregatedGrads(grads, op, gradient_uid, loop_state,
+ aggregation_method)
+ if loop_state:
+ loop_state.ExitGradWhileContext(op, before=True)
+
+ grad_fn = None
+ func_call = None
+ is_partitioned_call = _IsPartitionedCall(op)
+ # pylint: disable=protected-access
+ is_func_call = (
+ src_graph._is_function(op.type) or is_partitioned_call)
+ # pylint: enable=protected-access
+ has_out_grads = any(isinstance(g, ops.Tensor) or g for g in out_grads)
+ if has_out_grads and (op not in stop_ops):
+ try:
+ grad_fn = ops.get_gradient_function(op)
+ except LookupError:
+ if is_func_call:
+ if is_partitioned_call:
+ func_call = src_graph._get_function( # pylint: disable=protected-access
+ compat.as_bytes(op.get_attr("f").name))
+ else:
+ func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
+ # Note that __defun is not set if the graph is
+ # imported. If it's set, we prefer to access the original
+ # defun.
+ func_call = getattr(op, "__defun", func_call)
+ grad_fn = func_call.python_grad_func
+ else:
+ raise LookupError(
+ "No gradient defined for operation '%s' (op type: %s)" %
+ (op.name, op.type))
+ if loop_state:
+ loop_state.EnterGradWhileContext(op, before=False)
+
+ # NOTE(skyewm): We don't support computing gradients wrt a loop variable
+ # unless it's within the context of a single iteration (i.e. the
+ # gradient is wrt to the loop parameter in the body function, not wrt or
+ # through the initial value). This means if we're in a while loop
+ # context, we should never see a switch node from this context.
+ # pylint: disable=protected-access
+ if (control_flow_util.IsSwitch(op) and
+ op._control_flow_context is not None and
+ op._control_flow_context.IsWhileContext() and
+ op._control_flow_context ==
+ ops.get_default_graph()._get_control_flow_context()):
+ _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+ # pylint: enable=protected-access
+
+ if (grad_fn or is_func_call) and has_out_grads:
+ # NOTE: If _AggregatedGrads didn't compute a value for the i'th
+ # output, it means that the cost does not depend on output[i],
+ # therefore dC/doutput[i] is 0.
+ for i, out_grad in enumerate(out_grads):
+ if (not isinstance(out_grad, ops.Tensor) and not out_grad) and (
+ (not grad_fn and is_func_call) or IsTrainable(op.outputs[i])):
+ # Only trainable outputs or outputs for a function call that
+ # will use SymbolicGradient get a zero gradient. Gradient
+ # functions should ignore the gradient for other outputs.
+ # TODO(apassos) gradients of resource handles might be an
+ # issue here because of zeros.
+ if loop_state:
+ out_grads[i] = loop_state.ZerosLike(op, i)
+ else:
+ out_grads[i] = control_flow_ops.ZerosLikeOutsideLoop(op, i)
+ with ops.name_scope(op.name + "_grad"):
+ # pylint: disable=protected-access
+ with src_graph._original_op(op):
+ # pylint: enable=protected-access
+ if grad_fn:
+ # If grad_fn was found, do not use SymbolicGradient even for
+ # functions.
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: grad_fn(op, *out_grads))
+ else:
+ # For function call ops, we add a 'SymbolicGradient'
+ # node to the graph to compute gradients.
+ in_grads = _MaybeCompile(grad_scope, op, func_call,
+ lambda: _SymGrad(op, out_grads))
+ in_grads = _AsList(in_grads)
+ _VerifyGeneratedGradients(in_grads, op)
+ if gate_gradients and len([x for x in in_grads
+ if x is not None]) > 1:
+ with ops.device(None):
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ None,
+ gradient_uid,
+ ignore_existing=True):
+ in_grads = control_flow_ops.tuple(in_grads)
+ _LogOpGradients(op, out_grads, in_grads)
+ else:
+ # If no grad_fn is defined or none of out_grads is available,
+ # just propagate a list of None backwards.
+ in_grads = [None] * len(_NonEagerInputs(op, xs))
+ for i, (t_in, in_grad) in enumerate(zip(_NonEagerInputs(op, xs),
+ in_grads)):
+ if in_grad is not None:
+ if (isinstance(in_grad, ops.Tensor) and
+ t_in.dtype != dtypes.resource):
+ try:
+ in_grad.set_shape(t_in.get_shape())
+ except ValueError:
+ raise ValueError(
+ "Incompatible shapes between op input and calculated "
+ "input gradient. Forward operation: %s. Input index: %d. "
+ "Original input shape: %s. "
+ "Calculated input gradient shape: %s" %
+ (op.name, i, t_in.shape, in_grad.shape))
+ _SetGrad(grads, t_in, in_grad)
+ if loop_state:
+ loop_state.ExitGradWhileContext(op, before=False)
+
+ # Update pending count for the inputs of op and enqueue ready ops.
+ _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs)
+
+ if loop_state:
+ loop_state.PostProcessing()
+ return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
+
+
+def _HasAnyNotNoneGrads(grads, op):
+ """Return true iff op has real gradient."""
+ out_grads = _GetGrads(grads, op)
+ for out_grad in out_grads:
+ if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
+ return True
+ if out_grad and isinstance(out_grad, collections.Sequence):
+ if any(g is not None for g in out_grad):
+ return True
+ return False
+
+
+def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
+ xs):
+ """Update pending count for the inputs of op and enqueue ready ops."""
+ for x in _NonEagerInputs(op, xs):
+ pending_count[x.op] -= 1
+ ready = (pending_count[x.op] == 0)
+ if loop_state and not ready:
+ ready = pending_count[x.op] > 0 and control_flow_util.IsLoopSwitch(x.op)
+ if ready:
+ if control_flow_util.IsLoopExit(x.op):
+ # if x is an exit without real gradient, defer processing them.
+ grad_state = loop_state.GetGradState(x.op, before=False)
+ grad_state.deferred_exits.append(x)
+ grad_state.pending_exits_count -= 1
+ if grad_state.pending_exits_count == 0:
+ # We now have all the exits so process them.
+ has_not_none_grad = False
+ for y in grad_state.deferred_exits:
+ if _HasAnyNotNoneGrads(grads, y.op):
+ has_not_none_grad = True
+ queue.append(y.op)
+ else:
+ grad_state.unused_exits.append(y)
+ if has_not_none_grad:
+ # For an unused exit, if it has trainable outputs, backprop
+ # a zero gradient. Otherwise, just ignore it.
+ for y in grad_state.unused_exits:
+ if IsTrainable(y):
+ _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
+ queue.append(y.op)
+ else:
+ # All exits are "unused" so use None as gradient.
+ for y in grad_state.unused_exits:
+ queue.append(y.op)
+ else:
+ queue.append(x.op)
+
+
+def _SetGrad(grads, t, grad):
+ """Sets gradient "grad" in "grads" for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads:
+ op_grads = [[] for _ in xrange(len(op.outputs))]
+ grads[op] = op_grads
+ t_grads = op_grads[t.value_index]
+ if isinstance(t_grads, list):
+ t_grads.append(grad)
+ else:
+ assert control_flow_util.IsLoopSwitch(op)
+ op_grads[t.value_index] = grad
+
+
+def _GetGrad(grads, t, unconnected_gradients):
+ """Gets gradient for tensor "t"."""
+ op = t.op
+ op_grads = grads.get(op)
+ if not op_grads:
+ if unconnected_gradients == UnconnectedGradients.ZERO:
+ t_dtype = t.dtype if t.dtype != dtypes.resource else dtypes.float32
+ return array_ops.zeros_like(t, dtype=t_dtype)
+ elif unconnected_gradients == UnconnectedGradients.NONE:
+ return None
+ else:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
+ t_grad = op_grads[t.value_index]
+ assert not isinstance(
+ t_grad, list), ("gradients list should have been aggregated by now.")
+ return t_grad
+
+
+def _GetGrads(grads, op):
+ """Gets all gradients for op."""
+ if op in grads:
+ return grads[op]
+ else:
+ return [[] for _ in xrange(len(op.outputs))]
+
+
+def _HandleNestedIndexedSlices(grad):
+ assert isinstance(grad, ops.IndexedSlices)
+ if isinstance(grad.values, ops.Tensor):
+ return grad
+ else:
+ assert isinstance(grad.values, ops.IndexedSlices)
+ g = _HandleNestedIndexedSlices(grad.values)
+ return ops.IndexedSlices(g.values, array_ops.gather(
+ grad.indices, g.indices), g.dense_shape)
+
+
+def _AccumulatorShape(inputs):
+ shape = tensor_shape.unknown_shape()
+ for i in inputs:
+ if isinstance(i, ops.Tensor):
+ shape = shape.merge_with(i.get_shape())
+ return shape
+
+
+def _LogOpGradients(op, out_grads, in_grads):
+ """Log the in and out grads of an op."""
+ logging.vlog(1, "Gradient for '" + op.name + "'")
+
+ def _FilterGrad(x):
+ if x is None:
+ return False
+ if isinstance(x, (list, tuple)):
+ return bool(x)
+ else:
+ return True
+
+ logging.vlog(1, " in --> %s",
+ ", ".join([x.name for x in out_grads if _FilterGrad(x)]))
+ logging.vlog(1, " out --> %s",
+ ", ".join([x.name for x in in_grads if _FilterGrad(x)]))
+
+
+def _MultiDeviceAddN(tensor_list, gradient_uid):
+ """Adds tensors from potentially multiple devices."""
+ # Basic function structure comes from control_flow_ops.group().
+ # Sort tensors according to their devices.
+ tensors_on_device = collections.defaultdict(lambda: [])
+ for tensor in tensor_list:
+ tensors_on_device[tensor.device].append(tensor)
+
+ # For each device, add the tensors on that device first.
+ # Then gather the partial sums from multiple devices.
+ # TODO(sjhwang): Create hierarchical aggregation tree as pbar's suggestion.
+ # E.g., aggregate per GPU, then per task, and so on.
+ summands = []
+
+ def DeviceKey(dev):
+ return "" if dev is None else dev
+
+ for dev in sorted(six.iterkeys(tensors_on_device), key=DeviceKey):
+ tensors = tensors_on_device[dev]
+ with ops._colocate_with_for_gradient( # pylint: disable=protected-access
+ tensors[0].op,
+ gradient_uid,
+ ignore_existing=True):
+ summands.append(math_ops.add_n(tensors))
+
+ return math_ops.add_n(summands)
+
+
+@tf_export("AggregationMethod")
+class AggregationMethod(object):
+ """A class listing aggregation methods used to combine gradients.
+
+ Computing partial derivatives can require aggregating gradient
+ contributions. This class lists the various methods that can
+ be used to combine gradients in the graph:
+
+ * `ADD_N`: All of the gradient terms are summed as part of one
+ operation using the "AddN" op. It has the property that all
+ gradients must be ready before any aggregation is performed.
+ * `DEFAULT`: The system-chosen default aggregation method.
+ """
+ ADD_N = 0
+ DEFAULT = ADD_N
+ # The following are experimental and may not be supported in future releases.
+ EXPERIMENTAL_TREE = 1
+ EXPERIMENTAL_ACCUMULATE_N = 2
+
+
+def _AggregatedGrads(grads,
+ op,
+ gradient_uid,
+ loop_state,
+ aggregation_method=None):
+ """Get the aggregated gradients for op.
+
+ Args:
+ grads: The map of memoized gradients.
+ op: The op to get gradients for.
+ gradient_uid: A unique identifier within the graph indicating
+ which invocation of gradients is being executed. Used to cluster
+ ops for compilation.
+ loop_state: An object for maintaining the state of the while loops in the
+ graph. It is of type ControlFlowState. None if the graph
+ contains no while loops.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Accepted values are constants defined in the class `AggregationMethod`.
+
+ Returns:
+ A list of gradients, one per each output of `op`. If the gradients
+ for a particular output is a list, this function aggregates it
+ before returning.
+
+ Raises:
+ TypeError: if the incoming grads are not Tensors or IndexedSlices.
+ ValueError: if the arguments are invalid.
+
+ """
+ if aggregation_method is None:
+ aggregation_method = AggregationMethod.DEFAULT
+ if aggregation_method not in [
+ AggregationMethod.ADD_N, AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ ]:
+ raise ValueError(
+ "Invalid aggregation_method specified %s." % aggregation_method)
+ out_grads = _GetGrads(grads, op)
+ for i, out_grad in enumerate(out_grads):
+ if loop_state:
+ if isinstance(out_grad, (ops.Tensor, ops.IndexedSlices)):
+ assert control_flow_util.IsLoopSwitch(op)
+ continue
+ # Grads have to be Tensors or IndexedSlices
+ if (isinstance(out_grad, collections.Sequence) and not all(
+ isinstance(g, (ops.Tensor, ops.IndexedSlices))
+ for g in out_grad
+ if g is not None
+ )):
+ raise TypeError("gradients have to be either all Tensors "
+ "or all IndexedSlices")
+ # Aggregate multiple gradients, and convert [] to None.
+ if out_grad:
+ if len(out_grad) < 2:
+ used = "nop"
+ out_grads[i] = out_grad[0]
+ elif all(isinstance(g, ops.Tensor) for g in out_grad if g is not None):
+ tensor_shape = _AccumulatorShape(out_grad)
+ if (aggregation_method == AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ and len(out_grad) > 2 and tensor_shape.is_fully_defined()):
+ # The benefit of using AccumulateN is that its inputs can be combined
+ # in any order and this can allow the expression to be evaluated with
+ # a smaller memory footprint. When used with gpu_allocator_retry,
+ # it is possible to compute a sum of terms which are much larger than
+ # total GPU memory.
+ # AccumulateN can currently only be used if we know the shape for
+ # an accumulator variable. If this is not known, or if we only have
+ # 2 grads then we fall through to the "tree" case below.
+ used = "accumulate_n"
+ out_grads[i] = math_ops.accumulate_n(out_grad)
+ elif aggregation_method in [
+ AggregationMethod.EXPERIMENTAL_TREE,
+ AggregationMethod.EXPERIMENTAL_ACCUMULATE_N
+ ]:
+ # Aggregate all gradients by doing pairwise sums: this may
+ # reduce performance, but it can improve memory because the
+ # gradients can be released earlier.
+ #
+ # TODO(vrv): Consider replacing this with a version of
+ # tf.AddN() that eagerly frees its inputs as soon as they are
+ # ready, so the order of this tree does not become a problem.
+ used = "tree"
+ with ops.name_scope(op.name + "_gradient_sum"):
+ running_sum = out_grad[0]
+ for grad in out_grad[1:]:
+ running_sum = math_ops.add_n([running_sum, grad])
+ out_grads[i] = running_sum
+ else:
+ used = "add_n"
+ out_grads[i] = _MultiDeviceAddN(out_grad, gradient_uid)
+ logging.vlog(2, " _AggregatedGrads %d x %s using %s", len(out_grad),
+ tensor_shape, used)
+ else:
+ out_grads[i] = _AggregateIndexedSlicesGradients(out_grad)
+ else: # not out_grad
+ # out_grads[i] is [], thus its aggregation is simply None.
+ out_grads[i] = None
+ return out_grads
+
+
+def _AggregateIndexedSlicesGradients(grads):
+ """Aggregates gradients of type `IndexedSlices` by concatenation."""
+ if len(grads) < 1:
+ return None
+ elif len(grads) == 1:
+ return grads[0]
+ else:
+ grads = math_ops._as_indexed_slices_list( # pylint: disable=protected-access
+ [g for g in grads if g is not None])
+ grads = [_HandleNestedIndexedSlices(x) for x in grads] # pylint: disable=protected-access
+ # Form IndexedSlices out of the concatenated values and indices.
+ concat_grad = ops.IndexedSlices(
+ array_ops.concat([x.values for x in grads], axis=0),
+ array_ops.concat([x.indices for x in grads], axis=0),
+ grads[0].dense_shape)
+
+ return concat_grad
diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py
index 361befa..b032d64 100644
--- a/tensorflow/python/ops/image_ops_test.py
+++ b/tensorflow/python/ops/image_ops_test.py
@@ -61,7 +61,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to HSV and back, as a batch and individually
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_hsv(batch0)
batch2 = image_ops.hsv_to_rgb(batch1)
@@ -82,7 +82,7 @@
data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1]
for nptype in [np.float32, np.float64]:
rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255.
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
hsv = image_ops.rgb_to_hsv(rgb_np)
rgb = image_ops.hsv_to_rgb(hsv)
rgb_tf = self.evaluate(rgb)
@@ -101,7 +101,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to YIQ and back, as a batch and individually
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_yiq(batch0)
batch2 = image_ops.yiq_to_rgb(batch1)
@@ -131,7 +131,7 @@
inp = np.random.rand(*shape).astype(nptype)
# Convert to YUV and back, as a batch and individually
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
batch0 = constant_op.constant(inp)
batch1 = image_ops.rgb_to_yuv(batch0)
batch2 = image_ops.yuv_to_rgb(batch1)
@@ -173,7 +173,7 @@
def _TestRGBToGrayscale(self, x_np):
y_np = self._RGBToGrayscale(x_np)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.rgb_to_grayscale(x_tf)
y_tf = self.evaluate(y)
@@ -195,7 +195,7 @@
y_np = np.array(
[[1, 1, 1], [2, 2, 2]], dtype=np.uint8).reshape([1, 1, 2, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = self.evaluate(y)
@@ -205,7 +205,7 @@
x_np = np.array([[1, 2]], dtype=np.uint8).reshape([1, 2, 1])
y_np = np.array([[1, 1, 1], [2, 2, 2]], dtype=np.uint8).reshape([1, 2, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.grayscale_to_rgb(x_tf)
y_tf = self.evaluate(y)
@@ -216,23 +216,23 @@
# Shape inference works and produces expected output where possible
rgb_shape = [7, None, 19, 3]
gray_shape = rgb_shape[:-1] + [1]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
rgb_tf = array_ops.placeholder(dtypes.uint8, shape=rgb_shape)
gray = image_ops.rgb_to_grayscale(rgb_tf)
self.assertEqual(gray_shape, gray.get_shape().as_list())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
gray_tf = array_ops.placeholder(dtypes.uint8, shape=gray_shape)
rgb = image_ops.grayscale_to_rgb(gray_tf)
self.assertEqual(rgb_shape, rgb.get_shape().as_list())
# Shape inference does not break for unknown shapes
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
rgb_tf_unknown = array_ops.placeholder(dtypes.uint8)
gray_unknown = image_ops.rgb_to_grayscale(rgb_tf_unknown)
self.assertFalse(gray_unknown.get_shape())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
gray_tf_unknown = array_ops.placeholder(dtypes.uint8)
rgb_unknown = image_ops.grayscale_to_rgb(gray_tf_unknown)
self.assertFalse(rgb_unknown.get_shape())
@@ -364,7 +364,7 @@
y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -379,7 +379,7 @@
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -394,7 +394,7 @@
y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_hue(x, delta)
y_tf = self.evaluate(y)
@@ -419,7 +419,7 @@
return y_v.reshape(x_np.shape)
def _adjustHueTf(self, x_np, delta_h):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np)
y = image_ops.adjust_hue(x, delta_h)
y_tf = self.evaluate(y)
@@ -850,7 +850,7 @@
y_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -865,7 +865,7 @@
y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -880,7 +880,7 @@
y_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128]
y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.adjust_saturation(x, saturation_factor)
y_tf = self.evaluate(y)
@@ -920,7 +920,7 @@
"gb_same",
"rgb_same",
]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
for x_shape in x_shapes:
for test_style in test_styles:
x_np = np.random.rand(*x_shape) * 255.
@@ -947,7 +947,7 @@
def testInvolutionLeftRight(self):
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
y_tf = self.evaluate(y)
@@ -957,7 +957,7 @@
x_np = np.array(
[[[1, 2, 3], [1, 2, 3]], [[1, 2, 3], [1, 2, 3]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(image_ops.flip_left_right(x_tf))
y_tf = self.evaluate(y)
@@ -968,7 +968,7 @@
x_np = np.array([[1, 2, 3], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(x_tf)
self.assertTrue(y.op.name.startswith("flip_left_right"))
@@ -983,7 +983,7 @@
[[[3, 2, 1], [3, 2, 1]], [[3, 2, 1], [3, 2, 1]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_left_right(x_tf)
y_tf = self.evaluate(y)
@@ -995,7 +995,7 @@
y_np = np.array([[3, 2, 1], [3, 2, 1]], dtype=np.uint8).reshape([2, 3, 1])
seed = 42
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_left_right(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_left_right"))
@@ -1035,7 +1035,7 @@
x_np = np.vstack([x_np_raw for _ in range(batch_size)])
y_np = np.vstack([y_np_raw for _ in range(batch_size)])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_left_right(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_left_right"))
@@ -1066,7 +1066,7 @@
def testInvolutionUpDown(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
y_tf = self.evaluate(y)
@@ -1077,7 +1077,7 @@
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(image_ops.flip_up_down(x_tf))
y_tf = self.evaluate(y)
@@ -1088,7 +1088,7 @@
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[4, 5, 6], [1, 2, 3]], dtype=np.uint8).reshape([2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(x_tf)
self.assertTrue(y.op.name.startswith("flip_up_down"))
@@ -1103,7 +1103,7 @@
[[[4, 5, 6], [1, 2, 3]], [[10, 11, 12], [7, 8, 9]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.flip_up_down(x_tf)
y_tf = self.evaluate(y)
@@ -1116,7 +1116,7 @@
seed = 42
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_up_down(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
@@ -1155,7 +1155,7 @@
x_np = np.vstack([x_np_raw for _ in range(batch_size)])
y_np = np.vstack([y_np_raw for _ in range(batch_size)])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.random_flip_up_down(x_tf, seed=seed)
self.assertTrue(y.op.name.startswith("random_flip_up_down"))
@@ -1186,7 +1186,7 @@
def testInvolutionTranspose(self):
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
y_tf = self.evaluate(y)
@@ -1197,7 +1197,7 @@
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]],
dtype=np.uint8).reshape([2, 2, 3, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose_image(image_ops.transpose_image(x_tf))
y_tf = self.evaluate(y)
@@ -1208,7 +1208,7 @@
x_np = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.uint8).reshape([2, 3, 1])
y_np = np.array([[1, 4], [2, 5], [3, 6]], dtype=np.uint8).reshape([3, 2, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose_image(x_tf)
self.assertTrue(y.op.name.startswith("transpose"))
@@ -1224,7 +1224,7 @@
[[[1, 4], [2, 5], [3, 6]], [[7, 10], [8, 11], [9, 12]]],
dtype=np.uint8).reshape([2, 3, 2, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x_tf = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.transpose_image(x_tf)
y_tf = self.evaluate(y)
@@ -1275,7 +1275,7 @@
def testRot90GroupOrder(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
rotated = image
for _ in xrange(4):
rotated = image_ops.rot90(rotated)
@@ -1283,7 +1283,7 @@
def testRot90GroupOrderWithBatch(self):
image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
rotated = image
for _ in xrange(4):
rotated = image_ops.rot90(rotated)
@@ -1292,7 +1292,7 @@
@test_util.run_deprecated_v1
def testRot90NumpyEquivalence(self):
image = np.arange(24, dtype=np.uint8).reshape([2, 4, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
k_placeholder = array_ops.placeholder(dtypes.int32, shape=[])
y_tf = image_ops.rot90(image, k_placeholder)
for k in xrange(4):
@@ -1302,7 +1302,7 @@
@test_util.run_deprecated_v1
def testRot90NumpyEquivalenceWithBatch(self):
image = np.arange(48, dtype=np.uint8).reshape([2, 2, 4, 3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
k_placeholder = array_ops.placeholder(dtypes.int32, shape=[])
y_tf = image_ops.rot90(image, k_placeholder)
for k in xrange(4):
@@ -1312,7 +1312,7 @@
class AdjustContrastTest(test_util.TensorFlowTestCase):
def _testContrast(self, x_np, y_np, contrast_factor):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_contrast(x, contrast_factor)
y_tf = self.evaluate(y)
@@ -1367,7 +1367,7 @@
return y_np
def _adjustContrastTf(self, x_np, contrast_factor):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np)
y = image_ops.adjust_contrast(x, contrast_factor)
y_tf = self.evaluate(y)
@@ -1401,7 +1401,7 @@
class AdjustBrightnessTest(test_util.TensorFlowTestCase):
def _testBrightness(self, x_np, y_np, delta, tol=1e-6):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_np.shape)
y = image_ops.adjust_brightness(x, delta)
y_tf = self.evaluate(y)
@@ -1468,7 +1468,7 @@
x_np = np.arange(0, np.prod(x_shape), dtype=np.int32).reshape(x_shape)
y_np = self._NumpyPerImageWhitening(x_np)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.per_image_standardization(x)
self.assertTrue(y.op.name.startswith("per_image_standardization"))
@@ -1479,14 +1479,14 @@
im_np = np.ones([19, 19, 3]).astype(np.float32) * 249
im = constant_op.constant(im_np)
whiten = image_ops.per_image_standardization(im)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
whiten_np = self.evaluate(whiten)
self.assertFalse(np.any(np.isnan(whiten_np)))
def testBatchWhitening(self):
imgs_np = np.random.uniform(0., 255., [4, 24, 24, 3])
whiten_np = [self._NumpyPerImageWhitening(img) for img in imgs_np]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
imgs = constant_op.constant(imgs_np)
whiten = image_ops.per_image_standardization(imgs)
whiten_tf = self.evaluate(whiten)
@@ -1514,7 +1514,7 @@
if not use_tensor_inputs:
self.assertTrue(y.get_shape().is_fully_defined())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
return y.eval(feed_dict=feed_dict)
def _assertReturns(self,
@@ -1693,7 +1693,7 @@
for x_shape in x_shapes:
x_np = np.ones(x_shape, dtype=np.float32)
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 1.0)
y_tf = self.evaluate(y)
@@ -1708,7 +1708,7 @@
dtype=np.int32).reshape(x_shape)
y_np = np.array([[3, 4, 5, 6], [3, 4, 5, 6]]).reshape([2, 4, 1])
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 0.5)
y_tf = self.evaluate(y)
@@ -1724,7 +1724,7 @@
dtype=np.int32).reshape(x_shape)
y_np = np.array([[[3, 4, 5, 6], [3, 4, 5, 6]],
[[6, 5, 4, 3], [6, 5, 4, 3]]]).reshape([2, 2, 4, 1])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
x = constant_op.constant(x_np, shape=x_shape)
y = image_ops.central_crop(x, 0.5)
y_tf = self.evaluate(y)
@@ -1741,7 +1741,7 @@
x_np = np.zeros(x_shape, dtype=np.int32)
y_np = np.zeros(y_shape, dtype=np.int32)
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
x = array_ops.placeholder(shape=x_shape, dtype=dtypes.int32)
y = image_ops.central_crop(x, 0.33)
y_tf = y.eval(feed_dict={x: x_np})
@@ -1792,7 +1792,7 @@
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
x = constant_op.constant(x_np, shape=x_shape)
with self.assertRaises(ValueError):
_ = image_ops.central_crop(x, 0.0)
@@ -1804,7 +1804,7 @@
for x_shape in x_shapes:
x_np = np.ones(x_shape, dtype=np.float32)
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
x = constant_op.constant(x_np, shape=x_shape)
with self.assertRaises(ValueError):
_ = image_ops.central_crop(x, 0.5)
@@ -1814,7 +1814,7 @@
x_shape = [13, 9, 3]
x_np = np.ones(x_shape, dtype=np.float32)
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
y = image_ops.central_crop(x_np, 1.0)
self.assertTrue(y.op.name.startswith("central_crop"))
@@ -1839,7 +1839,7 @@
if not use_tensor_inputs:
self.assertTrue(y.get_shape().is_fully_defined())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
return y.eval(feed_dict=feed_dict)
def _assertReturns(self,
@@ -1899,7 +1899,7 @@
i = constant_op.constant([1, 0, 4, 3], dtype=dtypes.int64)
y_tf = image_ops.pad_to_bounding_box(x, i[0], i[1], i[2], i[3])
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(y, self.evaluate(y_tf))
@test_util.run_deprecated_v1
@@ -2034,7 +2034,7 @@
fraction_object_covered = []
num_iter = 1000
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image_tf = constant_op.constant(image, shape=image.shape)
image_size_tf = constant_op.constant(
image_size_np, shape=image_size_np.shape)
@@ -2164,7 +2164,7 @@
@test_util.run_deprecated_v1
def testSampleDistortedBoundingBoxShape(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image_size = constant_op.constant(
[40, 50, 1], shape=[3], dtype=dtypes.int32)
bounding_box = constant_op.constant(
@@ -2202,7 +2202,7 @@
def testDefaultMinObjectCovered(self):
# By default min_object_covered=0.1 if not provided
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image_size = constant_op.constant(
[40, 50, 1], shape=[3], dtype=dtypes.int32)
bounding_box = constant_op.constant(
@@ -2275,7 +2275,7 @@
img_np = np.array(data, dtype=nptype).reshape(img_shape)
for opt in self.OPTIONS:
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width], opt)
yshape = array_ops.shape(y)
@@ -2284,7 +2284,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = image_ops.resize_images(image, [target_height, target_width],
@@ -2308,7 +2308,7 @@
img_np = np.array(data, dtype=np.uint8).reshape(img_shape)
for opt in self.OPTIONS:
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, new_size, opt)
yshape = array_ops.shape(y)
@@ -2317,7 +2317,7 @@
self.assertAllClose(resized, img_np, atol=1e-5)
# Resizing with a single image must leave the shape unchanged also.
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
img_single = img_np.reshape(single_shape)
image = constant_op.constant(img_single, shape=single_shape)
y = image_ops.resize_images(image, new_size, self.OPTIONS[0])
@@ -2422,7 +2422,7 @@
for opt in self.OPTIONS:
if test.is_gpu_available() and self.shouldRunOnGPU(opt, nptype):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
opt)
@@ -2457,7 +2457,7 @@
image_ops.ResizeMethod.BILINEAR,
image_ops.ResizeMethod.NEAREST_NEIGHBOR, image_ops.ResizeMethod.AREA
]:
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
img_np = np.array(data, dtype=nptype).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(
@@ -2493,7 +2493,7 @@
image_ops.ResizeMethod.BILINEAR,
image_ops.ResizeMethod.NEAREST_NEIGHBOR, image_ops.ResizeMethod.AREA
]:
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
img_np = np.array(data, dtype=nptype).reshape(img_shape)
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(
@@ -2521,7 +2521,7 @@
75, 81, 80, 72, 69, 70, 105, 112, 75, 36, 45, 92, 111, 105
]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
image_ops.ResizeMethod.BICUBIC)
@@ -2544,7 +2544,7 @@
73, 33, 23, 39, 73, 33, 23, 39, 14, 16, 19, 21, 14, 16, 19, 21
]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant(img_np, shape=img_shape)
y = image_ops.resize_images(image, [target_height, target_width],
image_ops.ResizeMethod.AREA)
@@ -2562,7 +2562,7 @@
for align_corners in [True, False]:
img_np = np.arange(
0, np.prod(input_shape), dtype=nptype).reshape(input_shape)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant(img_np, shape=input_shape)
new_size = constant_op.constant([target_height, target_width])
out_op = image_ops.resize_images(
@@ -2571,7 +2571,7 @@
image_ops.ResizeMethod.NEAREST_NEIGHBOR,
align_corners=align_corners)
gpu_val = self.evaluate(out_op)
- with self.test_session(use_gpu=False):
+ with self.cached_session(use_gpu=False):
image = constant_op.constant(img_np, shape=input_shape)
new_size = constant_op.constant([target_height, target_width])
out_op = image_ops.resize_images(
@@ -2593,7 +2593,7 @@
0, np.prod(input_shape), dtype=nptype).reshape(input_shape)
value = {}
for use_gpu in [True, False]:
- with self.test_session(use_gpu=use_gpu):
+ with self.cached_session(use_gpu=use_gpu):
image = constant_op.constant(img_np, shape=input_shape)
new_size = constant_op.constant([target_height, target_width])
out_op = image_ops.resize_images(
@@ -2628,7 +2628,7 @@
@test_util.run_deprecated_v1
def testNameScope(self):
img_shape = [1, 3, 2, 1]
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
single_image = array_ops.placeholder(dtypes.float32, shape=[50, 60, 3])
y = image_ops.resize_images(single_image, [55, 66])
self.assertTrue(y.op.name.startswith("resize"))
@@ -2647,7 +2647,7 @@
y = image_ops.resize_images(x_tensor, target_max,
preserve_aspect_ratio=preserve_aspect_ratio)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
return y.eval(feed_dict=feed_dict)
def _assertResizeEqual(self, x, x_shape, y, y_shape,
@@ -2745,7 +2745,7 @@
if not use_tensor_inputs:
self.assertTrue(y.get_shape().is_fully_defined())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
return y.eval(feed_dict=feed_dict)
def _assertReturns(self,
@@ -2843,7 +2843,7 @@
if not use_tensor_inputs:
self.assertTrue(y.get_shape().is_fully_defined())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
return y.eval(feed_dict=feed_dict)
def _assertReturns(self,
@@ -3098,7 +3098,7 @@
# Read a real jpeg and verify shape
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1.jpg")
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
jpeg0 = io_ops.read_file(path)
image0 = image_ops.decode_jpeg(jpeg0)
image1 = image_ops.decode_jpeg(image_ops.encode_jpeg(image0))
@@ -3114,7 +3114,7 @@
cmyk_path = os.path.join(base, "jpeg_merge_test1_cmyk.jpg")
shape = 256, 128, 3
for channels in 3, 0:
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
rgb = image_ops.decode_jpeg(
io_ops.read_file(rgb_path), channels=channels)
cmyk = image_ops.decode_jpeg(
@@ -3171,7 +3171,7 @@
self.evaluate(result)
def testSynthetic(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Encode it, then decode it, then encode it
image0 = constant_op.constant(_SimpleColorRamp())
jpeg0 = image_ops.encode_jpeg(image0)
@@ -3192,7 +3192,7 @@
self.assertLessEqual(len(jpeg0), 6000)
def testSyntheticFasterAlgorithm(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Encode it, then decode it, then encode it
image0 = constant_op.constant(_SimpleColorRamp())
jpeg0 = image_ops.encode_jpeg(image0)
@@ -3216,7 +3216,7 @@
self.assertLessEqual(len(jpeg0), 6000)
def testDefaultDCTMethodIsIntegerFast(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Compare decoding with both dct_option=INTEGER_FAST and
# default. They should be the same.
image0 = constant_op.constant(_SimpleColorRamp())
@@ -3230,7 +3230,7 @@
@test_util.run_deprecated_v1
def testShape(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
jpeg = constant_op.constant("nonsense")
for channels in 0, 1, 3:
image = image_ops.decode_jpeg(jpeg, channels=channels)
@@ -3242,7 +3242,7 @@
# Read a real jpeg and verify shape.
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1.jpg")
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
jpeg = io_ops.read_file(path)
# Extract shape without decoding.
[image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
@@ -3253,7 +3253,7 @@
# Read a cmyk jpeg image, and verify its shape.
path = ("tensorflow/core/lib/jpeg/testdata/"
"jpeg_merge_test1_cmyk.jpg")
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
jpeg = io_ops.read_file(path)
[image_shape] = sess.run([image_ops.extract_jpeg_shape(jpeg)])
# Cmyk jpeg image has 4 channels.
@@ -3269,7 +3269,7 @@
(3, "lena_palette.png"), (4, "lena_palette_trns.png"))
for channels_in, filename in inputs:
for channels in 0, 1, 3, 4:
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
png0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_png(png0, channels=channels)
png0, image0 = self.evaluate([png0, image0])
@@ -3279,7 +3279,7 @@
self.assertAllEqual(image0, self.evaluate(image1))
def testSynthetic(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Encode it, then decode it
image0 = constant_op.constant(_SimpleColorRamp())
png0 = image_ops.encode_png(image0, compression=7)
@@ -3294,7 +3294,7 @@
self.assertLessEqual(len(png0), 750)
def testSyntheticUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Encode it, then decode it
image0 = constant_op.constant(_SimpleColorRamp(), dtype=dtypes.uint16)
png0 = image_ops.encode_png(image0, compression=7)
@@ -3309,7 +3309,7 @@
self.assertLessEqual(len(png0), 1500)
def testSyntheticTwoChannel(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Strip the b channel from an rgb image to get a two-channel image.
gray_alpha = _SimpleColorRamp()[:, :, 0:2]
image0 = constant_op.constant(gray_alpha)
@@ -3320,7 +3320,7 @@
self.assertAllEqual(image0, image1)
def testSyntheticTwoChannelUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
# Strip the b channel from an rgb image to get a two-channel image.
gray_alpha = _SimpleColorRamp()[:, :, 0:2]
image0 = constant_op.constant(gray_alpha, dtype=dtypes.uint16)
@@ -3332,7 +3332,7 @@
@test_util.run_deprecated_v1
def testShape(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
png = constant_op.constant("nonsense")
for channels in 0, 1, 3:
image = image_ops.decode_png(png, channels=channels)
@@ -3350,7 +3350,7 @@
STRIDE = 5
shape = (12, HEIGHT, WIDTH, 3)
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
gif0 = io_ops.read_file(prefix + filename)
image0 = image_ops.decode_gif(gif0)
gif0, image0 = self.evaluate([gif0, image0])
@@ -3377,7 +3377,7 @@
@test_util.run_deprecated_v1
def testShape(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
gif = constant_op.constant("nonsense")
image = image_ops.decode_gif(gif)
self.assertEqual(image.get_shape().as_list(), [None, None, None, 3])
@@ -3389,7 +3389,7 @@
x_np = np.array(original, dtype=original_dtype.as_numpy_dtype())
y_np = np.array(expected, dtype=output_dtype.as_numpy_dtype())
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant(x_np)
y = image_ops.convert_image_dtype(image, output_dtype)
self.assertTrue(y.dtype == output_dtype)
@@ -3405,7 +3405,7 @@
@test_util.run_deprecated_v1
def testNoConvert(self):
# Make sure converting to the same data type creates only an identity op
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
image = constant_op.constant([1], dtype=dtypes.uint8)
image_ops.convert_image_dtype(image, dtypes.uint8)
y = image_ops.convert_image_dtype(image, dtypes.uint8)
@@ -3415,7 +3415,7 @@
@test_util.run_deprecated_v1
def testConvertBetweenInteger(self):
# Make sure converting to between integer types scales appropriately
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self._convert([0, 255], dtypes.uint8, dtypes.int16, [0, 255 * 128])
self._convert([0, 32767], dtypes.int16, dtypes.uint8, [0, 255])
self._convert([0, 2**32], dtypes.int64, dtypes.int32, [0, 1])
@@ -3424,7 +3424,7 @@
@test_util.run_deprecated_v1
def testConvertBetweenFloat(self):
# Make sure converting to between float types does nothing interesting
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self._convert([-1.0, 0, 1.0, 200000], dtypes.float32, dtypes.float64,
[-1.0, 0, 1.0, 200000])
self._convert([-1.0, 0, 1.0, 200000], dtypes.float64, dtypes.float32,
@@ -3433,7 +3433,7 @@
@test_util.run_deprecated_v1
def testConvertBetweenIntegerAndFloat(self):
# Make sure converting from and to a float type scales appropriately
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self._convert([0, 1, 255], dtypes.uint8, dtypes.float32,
[0, 1.0 / 255.0, 1])
self._convert([0, 1.1 / 255.0, 1], dtypes.float32, dtypes.uint8,
@@ -3441,7 +3441,7 @@
@test_util.run_deprecated_v1
def testConvertBetweenInt16AndInt8(self):
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
# uint8, uint16
self._convert([0, 255 * 256], dtypes.uint16, dtypes.uint8, [0, 255])
self._convert([0, 255], dtypes.uint8, dtypes.uint16, [0, 255 * 256])
@@ -3472,7 +3472,7 @@
"""
# Create a TensorFlow session.
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
# Add a constant to the TensorFlow graph that holds the input.
x_tf = constant_op.constant(x_np, shape=x_np.shape)
@@ -3860,7 +3860,7 @@
img = array_ops.placeholder(dtype=dtypes.float32)
img_np = np.array((2, 2))
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img, img)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img: img_np})
@@ -3873,7 +3873,7 @@
img1_np = np.array([1, 2, 2, 1])
img2_np = np.array([1, 3, 3, 1])
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
_, _, checks = image_ops_impl._verify_compatible_image_shapes(img1, img2)
with self.assertRaises(errors.InvalidArgumentError):
sess.run(checks, {img1: img1_np, img2: img2_np})
@@ -3891,7 +3891,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
q20 = self._LoadTestImage(sess, "cat_q20.jpg")
q72 = self._LoadTestImage(sess, "cat_q72.jpg")
q95 = self._LoadTestImage(sess, "cat_q95.jpg")
@@ -3912,7 +3912,7 @@
image2 = self._RandomImage((8, 8, 1), 1)
psnr = self._PSNR_NumPy(image1, image2, 1)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
tf_image1 = constant_op.constant(image1, shape=image1.shape,
dtype=dtypes.float32)
tf_image2 = constant_op.constant(image2, shape=image2.shape,
@@ -3926,7 +3926,7 @@
image2 = self._RandomImage((10, 8, 8, 1), 1)
psnr = self._PSNR_NumPy(image1, image2, 1)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
tf_image1 = constant_op.constant(image1, shape=image1.shape,
dtype=dtypes.float32)
tf_image2 = constant_op.constant(image2, shape=image2.shape,
@@ -3948,7 +3948,7 @@
self.assertNear(35.302, psnr3, 0.001)
# Test TensorFlow implementation.
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32)
tf_q72 = constant_op.constant(q72, shape=q72.shape, dtype=dtypes.float32)
tf_q95 = constant_op.constant(q95, shape=q95.shape, dtype=dtypes.float32)
@@ -3963,7 +3963,7 @@
def testInfinity(self):
q20, _, _ = self._LoadTestImages()
psnr = self._PSNR_NumPy(q20, q20, 1)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
tf_q20 = constant_op.constant(q20, shape=q20.shape, dtype=dtypes.float32)
tf_psnr = image_ops.psnr(tf_q20, tf_q20, 1, "psnr").eval()
self.assertAllClose(psnr, tf_psnr, atol=0.001)
@@ -3978,7 +3978,7 @@
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
psnr_float32 = image_ops.psnr(img1, img2, 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(
psnr_uint8.eval(), self.evaluate(psnr_float32), atol=0.001)
@@ -4003,7 +4003,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
return [self._LoadTestImage(sess, f) for f in self._filenames]
def _RandomImage(self, shape, max_val):
@@ -4018,7 +4018,7 @@
ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
ssim = image_ops.ssim(*ph, max_val=1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
scores = [ssim.eval(dict(zip(ph, t)))
for t in itertools.combinations_with_replacement(img, 2)]
self.assertAllClose(expected, np.squeeze(scores), atol=1e-4)
@@ -4033,7 +4033,7 @@
ssim = image_ops.ssim(constant_op.constant(img1),
constant_op.constant(img2), 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
def testBroadcast(self):
@@ -4045,7 +4045,7 @@
img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1.
ssim = image_ops.ssim(img1, img2, 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(ssim), atol=1e-4)
@test_util.run_deprecated_v1
@@ -4060,7 +4060,7 @@
ssim = image_ops.ssim(constant_op.constant(img1),
constant_op.constant(img2), 255)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertLess(ssim.eval(), 0)
@test_util.run_deprecated_v1
@@ -4073,7 +4073,7 @@
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim(img1, img2, 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
@@ -4098,7 +4098,7 @@
return np.expand_dims(im, axis=0)
def _LoadTestImages(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
return [self._LoadTestImage(sess, f) for f in self._filenames]
def _RandomImage(self, shape, max_val):
@@ -4116,7 +4116,7 @@
ph = [array_ops.placeholder(dtype=dtypes.float32) for _ in range(2)]
msssim = image_ops.ssim_multiscale(*ph, max_val=1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
scores = [msssim.eval(dict(zip(ph, t)))
for t in itertools.combinations_with_replacement(img, 2)]
@@ -4131,7 +4131,7 @@
msssim = image_ops.ssim_multiscale(*scaled_ph, max_val=1.0,
power_factors=(1, 1, 1, 1, 1))
grads = gradients.gradients(msssim, scalar)
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
np_grads = sess.run(grads, feed_dict={ph[0]: img[0], ph[1]: img[1]})
self.assertTrue(np.isfinite(np_grads).all())
@@ -4146,7 +4146,7 @@
msssim = image_ops.ssim_multiscale(constant_op.constant(img1),
constant_op.constant(img2), 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(msssim), 1e-4)
def testBroadcast(self):
@@ -4159,7 +4159,7 @@
img2 = array_ops.expand_dims(img, axis=1) # batch dims: 2, 1.
score_tensor = image_ops.ssim_multiscale(img1, img2, 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(expected, self.evaluate(score_tensor), 1e-4)
def testRange(self):
@@ -4169,7 +4169,7 @@
If any of the value is negative so that the geometric mean is not
well-defined, then treat the MS-SSIM score as zero.
"""
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
img1 = self._LoadTestImage(sess, "checkerboard1.png")
img2 = self._LoadTestImage(sess, "checkerboard3.png")
images = [img1, img2, np.zeros_like(img1),
@@ -4194,7 +4194,7 @@
img1 = image_ops.convert_image_dtype(img1, dtypes.float32)
img2 = image_ops.convert_image_dtype(img2, dtypes.float32)
ssim_float32 = image_ops.ssim_multiscale(img1, img2, 1.0)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
self.assertAllClose(
ssim_uint8.eval(), self.evaluate(ssim_float32), atol=0.001)
@@ -4235,7 +4235,7 @@
batch = constant_op.constant(batch)
assert batch.get_shape().as_list() == [2, 2, 3, 2]
dy, dx = image_ops.image_gradients(batch)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
actual_dy = self.evaluate(dy)
actual_dx = self.evaluate(dx)
self.assertAllClose(expected_dy, actual_dy)
@@ -4256,7 +4256,7 @@
expected = np.reshape([[[0, 0], [0, 12], [0, 0]],
[[0, 0], [0, 12], [0, 0]]], [1, 2, 3, 1, 2])
sobel = image_ops.sobel_edges(img)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
actual_sobel = self.evaluate(sobel)
self.assertAllClose(expected, actual_sobel)
@@ -4278,7 +4278,7 @@
expected_batch = np.concatenate([expected_two_channel] * batch_size, axis=0)
sobel = image_ops.sobel_edges(img)
- with self.test_session(use_gpu=True):
+ with self.cached_session(use_gpu=True):
actual_sobel = self.evaluate(sobel)
self.assertAllClose(expected_batch, actual_sobel)
@@ -4286,7 +4286,7 @@
class DecodeImageTest(test_util.TensorFlowTestCase):
def testJpegUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.uint16)
@@ -4296,7 +4296,7 @@
self.assertAllEqual(image0, image1)
def testPngUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.uint16)
@@ -4306,7 +4306,7 @@
self.assertAllEqual(image0, image1)
def testGifUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.uint16)
@@ -4316,7 +4316,7 @@
self.assertAllEqual(image0, image1)
def testBmpUint16(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.uint16)
@@ -4326,7 +4326,7 @@
self.assertAllEqual(image0, image1)
def testJpegFloat32(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/jpeg/testdata"
jpeg0 = io_ops.read_file(os.path.join(base, "jpeg_merge_test1.jpg"))
image0 = image_ops.decode_image(jpeg0, dtype=dtypes.float32)
@@ -4336,7 +4336,7 @@
self.assertAllEqual(image0, image1)
def testPngFloat32(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/png/testdata"
png0 = io_ops.read_file(os.path.join(base, "lena_rgba.png"))
image0 = image_ops.decode_image(png0, dtype=dtypes.float32)
@@ -4346,7 +4346,7 @@
self.assertAllEqual(image0, image1)
def testGifFloat32(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/gif/testdata"
gif0 = io_ops.read_file(os.path.join(base, "scan.gif"))
image0 = image_ops.decode_image(gif0, dtype=dtypes.float32)
@@ -4356,7 +4356,7 @@
self.assertAllEqual(image0, image1)
def testBmpFloat32(self):
- with self.test_session(use_gpu=True) as sess:
+ with self.cached_session(use_gpu=True) as sess:
base = "tensorflow/core/lib/bmp/testdata"
bmp0 = io_ops.read_file(os.path.join(base, "lena.bmp"))
image0 = image_ops.decode_image(bmp0, dtype=dtypes.float32)
diff --git a/tensorflow/python/ops/linalg/adjoint_registrations.py b/tensorflow/python/ops/linalg/adjoint_registrations.py
new file mode 100644
index 0000000..59ec97d
--- /dev/null
+++ b/tensorflow/python/ops/linalg/adjoint_registrations.py
@@ -0,0 +1,127 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Registrations for LinearOperator.adjoint."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_adjoint
+from tensorflow.python.ops.linalg import linear_operator_algebra
+from tensorflow.python.ops.linalg import linear_operator_block_diag
+from tensorflow.python.ops.linalg import linear_operator_circulant
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_kronecker
+
+
+# By default, return LinearOperatorAdjoint which switched the .matmul
+# and .solve methods.
+@linear_operator_algebra.RegisterAdjoint(linear_operator.LinearOperator)
+def _adjoint_linear_operator(linop):
+ return linear_operator_adjoint.LinearOperatorAdjoint(
+ linop,
+ is_non_singular=linop.is_non_singular,
+ is_self_adjoint=linop.is_self_adjoint,
+ is_positive_definite=linop.is_positive_definite,
+ is_square=linop.is_square)
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_adjoint.LinearOperatorAdjoint)
+def _adjoint_adjoint_linear_operator(linop):
+ return linop.operator
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_identity.LinearOperatorIdentity)
+def _adjoint_identity(identity_operator):
+ return identity_operator
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_identity.LinearOperatorScaledIdentity)
+def _adjoint_scaled_identity(identity_operator):
+ multiplier = identity_operator.multiplier
+ if multiplier.dtype.is_complex:
+ multiplier = math_ops.conj(multiplier)
+
+ return linear_operator_identity.LinearOperatorScaledIdentity(
+ num_rows=identity_operator._num_rows, # pylint: disable=protected-access
+ multiplier=multiplier,
+ is_non_singular=identity_operator.is_non_singular,
+ is_self_adjoint=identity_operator.is_self_adjoint,
+ is_positive_definite=identity_operator.is_positive_definite,
+ is_square=True)
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_diag.LinearOperatorDiag)
+def _adjoint_diag(diag_operator):
+ diag = diag_operator.diag
+ if diag.dtype.is_complex:
+ diag = math_ops.conj(diag)
+
+ return linear_operator_diag.LinearOperatorDiag(
+ diag=diag,
+ is_non_singular=diag_operator.is_non_singular,
+ is_self_adjoint=diag_operator.is_self_adjoint,
+ is_positive_definite=diag_operator.is_positive_definite,
+ is_square=True)
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_block_diag.LinearOperatorBlockDiag)
+def _adjoint_block_diag(block_diag_operator):
+ # We take the adjoint of each block on the diagonal.
+ return linear_operator_block_diag.LinearOperatorBlockDiag(
+ operators=[
+ operator.adjoint() for operator in block_diag_operator.operators],
+ is_non_singular=block_diag_operator.is_non_singular,
+ is_self_adjoint=block_diag_operator.is_self_adjoint,
+ is_positive_definite=block_diag_operator.is_positive_definite,
+ is_square=True)
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_kronecker.LinearOperatorKronecker)
+def _adjoint_kronecker(kronecker_operator):
+ # Adjoint of a Kronecker product is the Kronecker product
+ # of adjoints.
+ return linear_operator_kronecker.LinearOperatorKronecker(
+ operators=[
+ operator.adjoint() for operator in kronecker_operator.operators],
+ is_non_singular=kronecker_operator.is_non_singular,
+ is_self_adjoint=kronecker_operator.is_self_adjoint,
+ is_positive_definite=kronecker_operator.is_positive_definite,
+ is_square=True)
+
+
+@linear_operator_algebra.RegisterAdjoint(
+ linear_operator_circulant.LinearOperatorCirculant)
+def _adjoint_circulant(circulant_operator):
+ spectrum = circulant_operator.spectrum
+ if spectrum.dtype.is_complex:
+ spectrum = math_ops.conj(spectrum)
+
+ # Conjugating the spectrum is sufficient to get the adjoint.
+ return linear_operator_circulant.LinearOperatorCirculant(
+ spectrum=spectrum,
+ is_non_singular=circulant_operator.is_non_singular,
+ is_self_adjoint=circulant_operator.is_self_adjoint,
+ is_positive_definite=circulant_operator.is_positive_definite,
+ is_square=True)
diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py
index eebe741..b9f8411 100644
--- a/tensorflow/python/ops/linalg/linalg.py
+++ b/tensorflow/python/ops/linalg/linalg.py
@@ -20,6 +20,7 @@
# go/tf-wildcard-import
# pylint: disable=wildcard-import,unused-import
+from tensorflow.python.ops.linalg import adjoint_registrations as _adjoint_registrations
from tensorflow.python.ops.linalg import cholesky_registrations as _cholesky_registrations
from tensorflow.python.ops.linalg import inverse_registrations as _inverse_registrations
from tensorflow.python.ops.linalg import linear_operator_algebra as _linear_operator_algebra
diff --git a/tensorflow/python/ops/linalg/linear_operator.py b/tensorflow/python/ops/linalg/linear_operator.py
index 4c99e86..8fa9f63 100644
--- a/tensorflow/python/ops/linalg/linear_operator.py
+++ b/tensorflow/python/ops/linalg/linear_operator.py
@@ -847,6 +847,26 @@
return self._solvevec(rhs, adjoint=adjoint)
+ def adjoint(self, name="adjoint"):
+ """Returns the adjoint of the current `LinearOperator`.
+
+ Given `A` representing this `LinearOperator`, return `A*`.
+ Note that calling `self.adjoint()` and `self.H` are equivalent.
+
+ Args:
+ name: A name for this `Op`.
+
+ Returns:
+ `LinearOperator` which represents the adjoint of this `LinearOperator`.
+ """
+ if self.is_self_adjoint is True: # pylint: disable=g-bool-id-comparison
+ return self
+ with self._name_scope(name):
+ return linear_operator_algebra.adjoint(self)
+
+ # self.H is equivalent to self.adjoint().
+ H = property(adjoint, None)
+
def inverse(self, name="inverse"):
"""Returns the Inverse of this `LinearOperator`.
diff --git a/tensorflow/python/ops/linalg/linear_operator_adjoint.py b/tensorflow/python/ops/linalg/linear_operator_adjoint.py
index 858e224..7ee4752 100644
--- a/tensorflow/python/ops/linalg/linear_operator_adjoint.py
+++ b/tensorflow/python/ops/linalg/linear_operator_adjoint.py
@@ -19,6 +19,7 @@
from __future__ import print_function
from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
@@ -83,7 +84,7 @@
r"""Initialize a `LinearOperatorAdjoint`.
`LinearOperatorAdjoint` is initialized with an operator `A`. The `solve`
- and `matmul` methods effectively flip the `adjoint` argument. E.g.
+ and `matmul` methods effectively flip the `adjoint` argument. E.g.
```
A = MyLinearOperator(...)
@@ -175,15 +176,24 @@
return self.operator.assert_self_adjoint()
def _shape(self):
- return self.operator.shape
+ # Rotate last dimension
+ shape = self.operator.shape
+ return shape[:-2].concatenate([shape[-1], shape[-2]])
def _shape_tensor(self):
- return self.operator.shape_tensor()
+ # Rotate last dimension
+ shape = self.operator.shape_tensor()
+ return array_ops.concat([
+ shape[:-2], [shape[-1], shape[-2]]], axis=-1)
def _matmul(self, x, adjoint=False, adjoint_arg=False):
return self.operator.matmul(
x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
+ def _matvec(self, x, adjoint=False, adjoint_arg=False):
+ return self.operator.matvec(
+ x, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
+
def _determinant(self):
if self.is_self_adjoint:
return self.operator.determinant()
@@ -201,7 +211,14 @@
return self.operator.solve(
rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
+ def _solvevec(self, rhs, adjoint=False, adjoint_arg=False):
+ return self.operator.solvevec(
+ rhs, adjoint=(not adjoint), adjoint_arg=adjoint_arg)
+
def _to_dense(self):
if self.is_self_adjoint:
return self.operator.to_dense()
return linalg.adjoint(self.operator.to_dense())
+
+ def _add_to_tensor(self, x):
+ return self.to_dense() + x
diff --git a/tensorflow/python/ops/linalg/linear_operator_algebra.py b/tensorflow/python/ops/linalg/linear_operator_algebra.py
index c1513fd..0d1eab4 100644
--- a/tensorflow/python/ops/linalg/linear_operator_algebra.py
+++ b/tensorflow/python/ops/linalg/linear_operator_algebra.py
@@ -25,6 +25,7 @@
from tensorflow.python.util import tf_inspect
+_ADJOINTS = {}
_CHOLESKY_DECOMPS = {}
_MATMUL = {}
_INVERSES = {}
@@ -46,6 +47,11 @@
return registry.get(tuple(r[1] for r in registered_combination), None)
+def _registered_adjoint(type_a):
+ """Get the Adjoint function registered for class a."""
+ return _registered_function([type_a], _ADJOINTS)
+
+
def _registered_cholesky(type_a):
"""Get the Cholesky function registered for class a."""
return _registered_function([type_a], _CHOLESKY_DECOMPS)
@@ -61,6 +67,29 @@
return _registered_function([type_a], _INVERSES)
+def adjoint(lin_op_a, name=None):
+ """Get the adjoint associated to lin_op_a.
+
+ Args:
+ lin_op_a: The LinearOperator to take the adjoint of.
+ name: Name to use for this operation.
+
+ Returns:
+ A LinearOperator that represents the adjoint of `lin_op_a`.
+
+ Raises:
+ NotImplementedError: If no Adjoint method is defined for the LinearOperator
+ type of `lin_op_a`.
+ """
+ adjoint_fn = _registered_adjoint(type(lin_op_a))
+ if adjoint_fn is None:
+ raise ValueError("No adjoint registered for {}".format(
+ type(lin_op_a)))
+
+ with ops.name_scope(name, "Adjoint"):
+ return adjoint_fn(lin_op_a)
+
+
def cholesky(lin_op_a, name=None):
"""Get the Cholesky factor associated to lin_op_a.
@@ -132,6 +161,48 @@
return inverse_fn(lin_op_a)
+class RegisterAdjoint(object):
+ """Decorator to register an Adjoint implementation function.
+
+ Usage:
+
+ @linear_operator_algebra.RegisterAdjoint(lin_op.LinearOperatorIdentity)
+ def _adjoint_identity(lin_op_a):
+ # Return the identity matrix.
+ """
+
+ def __init__(self, lin_op_cls_a):
+ """Initialize the LinearOperator registrar.
+
+ Args:
+ lin_op_cls_a: the class of the LinearOperator to decompose.
+ """
+ self._key = (lin_op_cls_a,)
+
+ def __call__(self, adjoint_fn):
+ """Perform the Adjoint registration.
+
+ Args:
+ adjoint_fn: The function to use for the Adjoint.
+
+ Returns:
+ adjoint_fn
+
+ Raises:
+ TypeError: if adjoint_fn is not a callable.
+ ValueError: if a Adjoint function has already been registered for
+ the given argument classes.
+ """
+ if not callable(adjoint_fn):
+ raise TypeError(
+ "adjoint_fn must be callable, received: {}".format(adjoint_fn))
+ if self._key in _ADJOINTS:
+ raise ValueError("Adjoint({}) has already been registered to: {}".format(
+ self._key[0].__name__, _ADJOINTS[self._key]))
+ _ADJOINTS[self._key] = adjoint_fn
+ return adjoint_fn
+
+
class RegisterCholesky(object):
"""Decorator to register a Cholesky implementation function.
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index a957c84..0383098 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -278,6 +278,23 @@
self._skip_if_tests_to_skip_contains("matmul_with_broadcast")
self._test_matmul(with_batch=False)
+ def test_adjoint(self):
+ self._skip_if_tests_to_skip_contains("adjoint")
+ for use_placeholder in self._use_placeholder_options:
+ for build_info in self._operator_build_infos:
+ for dtype in self._dtypes_to_test:
+ with self.test_session(graph=ops.Graph()) as sess:
+ sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
+ operator, mat = self._operator_and_matrix(
+ build_info, dtype, use_placeholder=use_placeholder)
+ op_adjoint = operator.adjoint().to_dense()
+ op_adjoint_h = operator.H.to_dense()
+ mat_adjoint = linalg.adjoint(mat)
+ op_adjoint_v, op_adjoint_h_v, mat_adjoint_v = sess.run(
+ [op_adjoint, op_adjoint_h, mat_adjoint])
+ self.assertAC(mat_adjoint_v, op_adjoint_v)
+ self.assertAC(mat_adjoint_v, op_adjoint_h_v)
+
def test_cholesky(self):
self._skip_if_tests_to_skip_contains("cholesky")
for use_placeholder in self._use_placeholder_options:
diff --git a/tensorflow/python/ops/list_ops.py b/tensorflow/python/ops/list_ops.py
index fcd5b81..87409eb 100644
--- a/tensorflow/python/ops/list_ops.py
+++ b/tensorflow/python/ops/list_ops.py
@@ -71,11 +71,12 @@
name=name)
-def tensor_list_get_item(input_handle, index, element_dtype, name=None):
+def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
+ name=None):
return gen_list_ops.tensor_list_get_item(
input_handle=input_handle,
index=index,
- element_shape=-1,
+ element_shape=_build_element_shape(element_shape),
element_dtype=element_dtype,
name=name)
@@ -119,9 +120,12 @@
name=None):
# Ignore the lengths output of TensorListConcat. It is only used during
# gradient computation.
- return gen_list_ops.tensor_list_concat(
- input_handle=input_handle, element_dtype=element_dtype,
- element_shape=element_shape, name=name)[0]
+ return gen_list_ops.tensor_list_concat_v2(
+ input_handle=input_handle,
+ element_dtype=element_dtype,
+ element_shape=_build_element_shape(element_shape),
+ leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
+ name=name)[0]
def tensor_list_split(tensor, element_shape, lengths, name=None):
@@ -175,22 +179,30 @@
@ops.RegisterGradient("TensorListConcat")
+@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
- # TODO(srbs): We lose the element_shape information in tensor_list_concat.
- # Consider providing that as an output of TensorListConcat?
- if dtensor.shape.rank is None:
- element_shape = None
- else:
- element_shape = [None] + dtensor.shape.as_list()[1:]
- return tensor_list_split(
+ """Gradient function for TensorListConcat."""
+ dlist = tensor_list_split(
dtensor,
- element_shape=_build_element_shape(element_shape),
+ element_shape=gen_list_ops.tensor_list_element_shape(
+ op.inputs[0], shape_type=dtypes.int32),
lengths=op.outputs[1])
+ if op.type == "TensorListConcatV2":
+ return dlist, None, None
+ else:
+ return dlist
@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
- return tensor_list_concat(dlist, element_dtype=op.inputs[0].dtype), None, None
+ tensor, _, lengths = op.inputs
+ element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
+ element_shape = array_ops.concat([[-1], element_shape], axis=0)
+ return gen_list_ops.tensor_list_concat_v2(
+ dlist,
+ element_shape=element_shape,
+ leading_dims=lengths,
+ element_dtype=op.inputs[0].dtype)[0], None, None
@ops.RegisterGradient("TensorListFromTensor")
@@ -238,7 +250,7 @@
list_grad = gen_list_ops.tensor_list_set_item(
dlist, index=index, item=array_ops.zeros_like(item))
index_grad = None
- element_grad = gen_list_ops.tensor_list_get_item(
+ element_grad = tensor_list_get_item(
dlist,
index,
element_shape=array_ops.shape(item),
@@ -317,4 +329,13 @@
if not shape:
return ops.convert_to_tensor(shape, dtype=dtypes.int32)
# Shape is a sequence of dimensions. Convert None dims to -1.
- return [d if d is not None else -1 for d in shape]
+ def convert(val):
+ if val is None:
+ return -1
+ if isinstance(val, ops.Tensor):
+ return val
+ if isinstance(val, tensor_shape.Dimension):
+ return val.value if val.value is not None else -1
+ return val
+
+ return [convert(d) for d in shape]
diff --git a/tensorflow/python/ops/logging_ops.py b/tensorflow/python/ops/logging_ops.py
index 3cb16eb..f05fbf4 100644
--- a/tensorflow/python/ops/logging_ops.py
+++ b/tensorflow/python/ops/logging_ops.py
@@ -25,6 +25,7 @@
import six
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -40,6 +41,14 @@
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
+# Register printing to the cell output if we are in a Colab or Jupyter Notebook.
+try:
+ get_ipython() # Exists in an ipython env like Jupyter or Colab
+ pywrap_tensorflow.TFE_Py_EnableInteractivePythonLogging()
+except NameError:
+ pass
+
+
# The python wrapper for Assert is in control_flow_ops, as the Assert
# call relies on certain conditionals for its dependencies. Use
# control_flow_ops.Assert.
@@ -193,9 +202,8 @@
(This prints "tensors: [0 1 2 ... 7 8 9] {2: [0 2 4 ... 14 16 18]}" to
sys.stdout)
- Note: This op is only partially compatible with Jupyter notebooks and colabs.
- Because it prints to the C++ standard out / standard error, this will go
- in the notebook kernel's console output, not in the notebook cell output.
+ Note: In Jupyter notebooks and colabs, this operator prints to the notebook
+ cell outputs. It will not write to the notebook kernel's console logs.
Args:
*inputs: Positional arguments that are the inputs to print. Inputs in the
diff --git a/tensorflow/python/ops/losses/BUILD b/tensorflow/python/ops/losses/BUILD
index 4aea026..9155d89 100644
--- a/tensorflow/python/ops/losses/BUILD
+++ b/tensorflow/python/ops/losses/BUILD
@@ -29,6 +29,7 @@
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python:weights_broadcast_ops",
+ "//tensorflow/python/distribute:distribute_lib",
],
)
diff --git a/tensorflow/python/ops/losses/losses_impl.py b/tensorflow/python/ops/losses/losses_impl.py
index 3393e75..1169c45 100644
--- a/tensorflow/python/ops/losses/losses_impl.py
+++ b/tensorflow/python/ops/losses/losses_impl.py
@@ -18,6 +18,7 @@
from __future__ import division
from __future__ import print_function
+from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -43,6 +44,8 @@
* `NONE`: Un-reduced weighted losses with the same shape as input.
* `SUM`: Scalar sum of weighted losses.
* `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
+ Note that when using `tf.distribute.Strategy`, this is the global batch
+ size across all the replicas that are contributing to a single step.
"""
NONE = "none"
@@ -69,8 +72,13 @@
* `SUM`: Scalar sum of weighted losses.
* `MEAN`: Scalar `SUM` divided by sum of weights. DEPRECATED.
* `SUM_OVER_BATCH_SIZE`: Scalar `SUM` divided by number of elements in losses.
+ Note that when using `tf.distribute.Strategy`, this is the global batch
+ size across all the replicas that are contributing to a single step.
* `SUM_OVER_NONZERO_WEIGHTS`: Scalar `SUM` divided by number of non-zero
weights. DEPRECATED.
+ Note that when using `tf.distribute.Strategy`, this is scaled by the
+ number of replicas that are contributing to a single step to get an
+ approximation to the global batch size.
* `SUM_BY_NONZERO_WEIGHTS`: Same as `SUM_OVER_NONZERO_WEIGHTS`.
"""
@@ -198,11 +206,6 @@
"""
Reduction.validate(reduction)
with ops.name_scope(scope, "weighted_loss", (losses, weights)):
- # Save the `reduction` argument for loss normalization when distributing
- # to multiple replicas.
- # TODO(josh11b): Associate it with the returned op for more precision.
- ops.get_default_graph()._last_loss_reduction = reduction # pylint: disable=protected-access
-
with ops.control_dependencies((
weights_broadcast_ops.assert_broadcastable(weights, losses),)):
losses = ops.convert_to_tensor(losses)
@@ -214,15 +217,17 @@
loss = weighted_losses
else:
loss = math_ops.reduce_sum(weighted_losses)
+ num_replicas = ( # Used to convert from local to global batch size.
+ distribution_strategy_context.get_strategy().num_replicas_in_sync)
if reduction == Reduction.MEAN:
- loss = _safe_mean(
- loss,
- math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
+ denom = (num_replicas *
+ math_ops.reduce_sum(array_ops.ones_like(losses) * weights))
+ loss = _safe_mean(loss, denom)
elif (reduction == Reduction.SUM_BY_NONZERO_WEIGHTS or
reduction == Reduction.SUM_OVER_NONZERO_WEIGHTS):
- loss = _safe_mean(loss, _num_present(losses, weights))
+ loss = _safe_mean(loss, num_replicas * _num_present(losses, weights))
elif reduction == Reduction.SUM_OVER_BATCH_SIZE:
- loss = _safe_mean(loss, _num_elements(losses))
+ loss = _safe_mean(loss, num_replicas * _num_elements(losses))
# Convert the result back to the input type.
loss = math_ops.cast(loss, input_dtype)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index c620712..7306a45 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -12,9 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Basic arithmetic operators.
+"""Math Operations.
-See the [python/math_ops](python/math_ops) guide.
+Note: Functions taking `Tensor` arguments can also take anything accepted by
+`tf.convert_to_tensor`.
+
+Note: Elementwise binary operations in TensorFlow follow [numpy-style
+broadcasting](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
+
+TensorFlow provides a variety of math functions including:
+
+* Basic arithmetic operators and trigonometric functions.
+* Special math functions (like: `tf.math.igamma` and `tf.math.zeta`)
+* Complex number functions (like: `tf.math.imag` and `tf.math.angle`)
+* Reductions and scans (like: `tf.math.reduce_mean` and `tf.math.cumsum`)
+* Segment functions (like: `tf.math.segment_sum`)
+
+See: `tf.linalg` for matrix and tensor functions.
+
+<a id=Segmentation></a>
+
+## About Segmentation
+
+TensorFlow provides several operations that you can use to perform common
+math computations on tensor segments.
+Here a segmentation is a partitioning of a tensor along
+the first dimension, i.e. it defines a mapping from the first dimension onto
+`segment_ids`. The `segment_ids` tensor should be the size of
+the first dimension, `d0`, with consecutive IDs in the range `0` to `k`,
+where `k<d0`.
+In particular, a segmentation of a matrix tensor is a mapping of rows to
+segments.
+
+For example:
+
+```python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+tf.segment_sum(c, tf.constant([0, 0, 1]))
+# ==> [[0 0 0 0]
+# [5 6 7 8]]
+```
+
+The standard `segment_*` functions assert that the segment indices are sorted.
+If you have unsorted indices use the equivalent `unsorted_segment_` function.
+Thses functions take an additional argument `num_segments` so that the output
+tensor can be efficiently allocated.
+
+``` python
+c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
+tf.unsorted_segment_sum(c, tf.constant([0, 1, 0]), num_segments=2)
+# ==> [[ 6, 8, 10, 12],
+# [-1, -2, -3, -4]]
+```
+
"""
from __future__ import absolute_import
from __future__ import division
@@ -3112,7 +3162,7 @@
r"""Computes the mean along segments of a tensor.
Read [the section on
- segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
@@ -3158,7 +3208,7 @@
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
Read [the section on
- segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+ segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
@@ -3205,7 +3255,7 @@
r"""Computes the sum along sparse segments of a tensor.
Read [the section on
- segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
@@ -3292,7 +3342,7 @@
r"""Computes the mean along sparse segments of a tensor.
Read [the section on
- segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
@@ -3337,7 +3387,7 @@
r"""Computes the mean along sparse segments of a tensor.
Read [the section on
- segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+ segmentation](https://tensorflow.org/api_docs/python/tf/math#Segmentation)
for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 6ca2b2a..a3d3c7b 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -50,7 +50,7 @@
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
- data_format=op.get_attr("data_format")),
+ data_format=op.get_attr("data_format").decode()),
nn_ops.conv2d(
grad,
op.inputs[1],
@@ -58,7 +58,7 @@
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
- data_format=op.get_attr("data_format"))
+ data_format=op.get_attr("data_format").decode())
]
@@ -73,7 +73,7 @@
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
- data_format=op.get_attr("data_format")), None,
+ data_format=op.get_attr("data_format").decode()), None,
nn_ops.conv2d(
op.inputs[0],
grad,
@@ -81,13 +81,13 @@
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
- data_format=op.get_attr("data_format"))
+ data_format=op.get_attr("data_format").decode())
]
@ops.RegisterGradient("Conv3D")
def _Conv3DGrad(op, grad):
- data_format = op.get_attr("data_format")
+ data_format = op.get_attr("data_format").decode()
return [
nn_ops.conv3d_backprop_input_v2(
array_ops.shape(op.inputs[0]),
@@ -110,7 +110,7 @@
@ops.RegisterGradient("Conv3DBackpropInputV2")
def _Conv3DBackpropInputGrad(op, grad):
- data_format = op.get_attr("data_format")
+ data_format = op.get_attr("data_format").decode()
return [
None,
nn_ops.conv3d_backprop_filter_v2(
@@ -133,7 +133,7 @@
@ops.RegisterGradient("Conv3DBackpropFilterV2")
def _Conv3DBackpropFilterGrad(op, grad):
- data_format = op.get_attr("data_format")
+ data_format = op.get_attr("data_format").decode()
return [
nn_ops.conv3d_backprop_input_v2(
array_ops.shape(op.inputs[0]),
@@ -161,7 +161,7 @@
ksize=op.get_attr("ksize"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format"))
+ data_format=op.get_attr("data_format").decode())
@ops.RegisterGradient("AvgPool3DGrad")
@@ -172,7 +172,7 @@
op.get_attr("ksize"),
op.get_attr("strides"),
op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ data_format=op.get_attr("data_format").decode()))
@ops.RegisterGradient("MaxPool3D")
@@ -184,7 +184,7 @@
ksize=op.get_attr("ksize"),
strides=op.get_attr("strides"),
padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format"))
+ data_format=op.get_attr("data_format").decode())
@ops.RegisterGradient("MaxPool3DGrad")
@@ -200,7 +200,7 @@
op.get_attr("ksize"),
op.get_attr("strides"),
padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ data_format=op.get_attr("data_format").decode()))
@ops.RegisterGradient("MaxPool3DGradGrad")
@@ -216,7 +216,7 @@
op.get_attr("ksize"),
op.get_attr("strides"),
padding=op.get_attr("padding"),
- data_format=op.get_attr("data_format")))
+ data_format=op.get_attr("data_format").decode()))
@ops.RegisterGradient("Softmax")
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 1853323..dc252c7 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -757,7 +757,7 @@
@tf_export(v1=["nn.sufficient_statistics"])
-def sufficient_statistics(x, axes, shift=None, keep_dims=False, name=None,
+def sufficient_statistics(x, axes, shift=None, keep_dims=None, name=None,
keepdims=None):
"""Calculate the sufficient statistics for the mean and variance of `x`.
@@ -786,6 +786,8 @@
axes = list(set(axes))
keep_dims = deprecated_argument_lookup(
"keepdims", keepdims, "keep_dims", keep_dims)
+ if keep_dims is None:
+ keep_dims = False
with ops.name_scope(name, "sufficient_statistics", [x, shift]):
x = ops.convert_to_tensor(x, name="x")
x_shape = x.get_shape()
@@ -877,7 +879,7 @@
axes,
shift=None, # pylint: disable=unused-argument
name=None,
- keep_dims=False,
+ keep_dims=None,
keepdims=None):
"""Calculate the mean and variance of `x`.
@@ -908,6 +910,8 @@
"""
keep_dims = deprecated_argument_lookup(
"keepdims", keepdims, "keep_dims", keep_dims)
+ if keep_dims is None:
+ keep_dims = False
with ops.name_scope(name, "moments", [x, axes]):
# The dynamic range of fp16 is too limited to support the collection of
# sufficient statistics. As a workaround we simply perform the operations
@@ -971,7 +975,7 @@
@tf_export(v1=["nn.weighted_moments"])
-def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=False,
+def weighted_moments(x, axes, frequency_weights, name=None, keep_dims=None,
keepdims=None):
"""Returns the frequency-weighted mean and variance of `x`.
@@ -990,6 +994,8 @@
"""
keep_dims = deprecated_argument_lookup(
"keepdims", keepdims, "keep_dims", keep_dims)
+ if keep_dims is None:
+ keep_dims = False
with ops.name_scope(name, "weighted_moments", [x, frequency_weights, axes]):
x = ops.convert_to_tensor(x, name="x")
frequency_weights = ops.convert_to_tensor(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index fd1173b..bd70d9f 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1503,10 +1503,11 @@
filters: A `Tensor`. Must have the same type as `input`.
A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
- strides: A list of `ints`.
- 1-D tensor of length 4. The stride of the sliding window for each
- dimension of `input`. The dimension order is determined by the value of
- `data_format`, see below for details.
+ strides: An int or list of `ints` that has length `1`, `2` or `4`. The
+ stride of the sliding window for each dimension of `input`. If a single
+ value is given it is replicated in the `H` and `W` dimension. By default
+ the `N` and `C` dimensions are set to 1. The dimension order is determined
+ by the value of `data_format`, see below for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
@@ -1521,20 +1522,20 @@
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
- dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
- 1-D tensor of length 4. The dilation factor for each dimension of
- `input`. If set to k > 1, there will be k-1 skipped cells between each
- filter element on that dimension. The dimension order is determined by the
- value of `data_format`, see above for details. Dilations in the batch and
- depth dimensions must be 1.
+ dilations: An int or list of `ints` that has length `1`, `2` or `4`,
+ defaults to 1. The dilation factor for each dimension of`input`. If a
+ single value is given it is replicated in the `H` and `W` dimension. By
+ default the `N` and `C` dimensions are set to 1. If set to k > 1, there
+ will be k-1 skipped cells between each filter element on that dimension.
+ The dimension order is determined by the value of `data_format`, see above
+ for details. Dilations in the batch and depth dimensions if a 4-d tensor
+ must be 1.
name: A name for the operation (optional).
Returns:
A `Tensor`. Has the same type as `input`.
"""
# pylint: enable=line-too-long
- if dilations is None:
- dilations = [1, 1, 1, 1]
return conv2d(input, # pylint: disable=redefined-builtin
filters,
strides,
@@ -1588,10 +1589,11 @@
filter: A `Tensor`. Must have the same type as `input`.
A 4-D tensor of shape
`[filter_height, filter_width, in_channels, out_channels]`
- strides: A list of `ints`.
- 1-D tensor of length 4. The stride of the sliding window for each
- dimension of `input`. The dimension order is determined by the value of
- `data_format`, see below for details.
+ strides: An int or list of `ints` that has length `1`, `2` or `4`. The
+ stride of the sliding window for each dimension of `input`. If a single
+ value is given it is replicated in the `H` and `W` dimension. By default
+ the `N` and `C` dimensions are set to 1. The dimension order is determined
+ by the value of `data_format`, see below for details.
padding: Either the `string `"SAME"` or `"VALID"` indicating the type of
padding algorithm to use, or a list indicating the explicit paddings at
the start and end of each dimension. When explicit padding is used and
@@ -1607,12 +1609,14 @@
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
- dilations: An optional list of `ints`. Defaults to `[1, 1, 1, 1]`.
- 1-D tensor of length 4. The dilation factor for each dimension of
- `input`. If set to k > 1, there will be k-1 skipped cells between each
- filter element on that dimension. The dimension order is determined by the
- value of `data_format`, see above for details. Dilations in the batch and
- depth dimensions must be 1.
+ dilations: An int or list of `ints` that has length `1`, `2` or `4`,
+ defaults to 1. The dilation factor for each dimension of`input`. If a
+ single value is given it is replicated in the `H` and `W` dimension. By
+ default the `N` and `C` dimensions are set to 1. If set to k > 1, there
+ will be k-1 skipped cells between each filter element on that dimension.
+ The dimension order is determined by the value of `data_format`, see above
+ for details. Dilations in the batch and depth dimensions if a 4-d tensor
+ must be 1.
name: A name for the operation (optional).
filters: Alias for filter.
@@ -1622,6 +1626,12 @@
filter = deprecation.deprecated_argument_lookup(
"filters", filters, "filter", filter)
padding, explicit_paddings = _convert_padding(padding)
+ if data_format is None:
+ data_format = "NHWC"
+ channel_index = 1 if data_format.startswith("NC") else 3
+
+ strides = _get_sequence(strides, 2, channel_index, "strides")
+ dilations = _get_sequence(dilations, 2, channel_index, "dilations")
return gen_nn_ops.conv2d(input, # pylint: disable=redefined-builtin
filter,
strides,
@@ -1912,8 +1922,11 @@
`in_channels` dimension must match that of `value`.
output_shape: A 1-D `Tensor` representing the output shape of the
deconvolution op.
- strides: A list of ints. The stride of the sliding window for each
- dimension of the input tensor.
+ strides: An int or list of `ints` that has length `1`, `2` or `4`. The
+ stride of the sliding window for each dimension of `input`. If a single
+ value is given it is replicated in the `H` and `W` dimension. By default
+ the `N` and `C` dimensions are set to 0. The dimension order is determined
+ by the value of `data_format`, see below for details.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
data_format: A string. 'NHWC' and 'NCHW' are supported.
@@ -1961,6 +1974,8 @@
raise ValueError("padding must be either VALID or SAME:"
" {}".format(padding))
+ strides = _get_sequence(strides, 2, axis, "strides")
+
return gen_nn_ops.conv2d_backprop_input(
input_sizes=output_shape_,
filter=filter,
@@ -3061,8 +3076,77 @@
name=name)
-@tf_export("nn.max_pool")
-def max_pool(value, ksize, strides, padding, data_format="NHWC", name=None):
+# pylint: disable=redefined-builtin
+@tf_export("nn.max_pool", v1=["nn.max_pool_v2"])
+def max_pool_v2(input, ksize, strides, padding, data_format=None, name=None):
+ """Performs the max pooling on the input.
+
+ Args:
+ input: Tensor of rank N+2, of shape `[batch_size] + input_spatial_shape +
+ [num_channels]` if `data_format` does not start with "NC" (default), or
+ `[batch_size, num_channels] + input_spatial_shape` if data_format starts
+ with "NC". Pooling happens over the spatial dimensions only.
+ ksize: An int or list of `ints` that has length `1`, `N` or `N+2`. The size
+ of the window for each dimension of the input tensor.
+ strides: An int or list of `ints` that has length `1`, `N` or `N+2`. The
+ stride of the sliding window for each dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
+ the "returns" section of `tf.nn.convolution` for details.
+ data_format: A string. Specifies the channel dimension. For N=1 it can be
+ either "NWC" (default) or "NCW", for N=2 it can be either "NHWC" (default)
+ or "NCHW" and for N=3 either "NDHWC" (default) or "NCDHW".
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` of format specified by `data_format`.
+ The max pooled output tensor.
+ """
+ if input.shape is not None:
+ n = len(input.shape) - 2
+ elif data_format is not None:
+ n = len(data_format) - 2
+ else:
+ raise ValueError(
+ "The input must have a rank or a data format must be given.")
+ if n < 1 or n > 3:
+ raise ValueError(
+ "Input tensor must be of rank 3, 4 or 5 but was {}.".format(n + 2))
+
+ if data_format is None:
+ channel_index = n + 1
+ else:
+ channel_index = 1 if data_format.startswith("NC") else n + 1
+
+ ksize = _get_sequence(ksize, n, channel_index, "ksize")
+ strides = _get_sequence(strides, n, channel_index, "strides")
+
+ max_pooling_ops = {
+ 1: max_pool1d,
+ 2: gen_nn_ops.max_pool,
+ 3: gen_nn_ops.max_pool3d
+ }
+
+ op = max_pooling_ops.get(n)
+ return op(
+ input,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ name=name)
+
+
+# pylint: enable=redefined-builtin
+
+
+@tf_export(v1=["nn.max_pool"])
+def max_pool(value,
+ ksize,
+ strides,
+ padding,
+ data_format="NHWC",
+ name=None,
+ input=None): # pylint: disable=redefined-builtin
"""Performs the max pooling on the input.
Args:
@@ -3075,17 +3159,18 @@
See the "returns" section of `tf.nn.convolution` for details.
data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
name: Optional name for the operation.
+ input: Alias for value.
Returns:
A `Tensor` of format specified by `data_format`.
The max pooled output tensor.
"""
+ value = deprecation.deprecated_argument_lookup("input", input, "value", value)
with ops.name_scope(name, "MaxPool", [value]) as name:
- value = ops.convert_to_tensor(value, name="input")
if data_format is None:
data_format = "NHWC"
-
channel_index = 1 if data_format.startswith("NC") else 3
+
ksize = _get_sequence(ksize, 2, channel_index, "ksize")
strides = _get_sequence(strides, 2, channel_index, "strides")
@@ -3099,6 +3184,88 @@
# pylint: disable=redefined-builtin
+@tf_export("nn.max_pool1d")
+def max_pool1d(input, ksize, strides, padding, data_format="NWC", name=None):
+ """Performs the max pooling on the input.
+
+ Note internally this op reshapes and uses the underlying 2d operation.
+
+ Args:
+ input: A 3-D `Tensor` of the format specified by `data_format`.
+ ksize: An int or list of `ints` that has length `1` or `3`. The size of the
+ window for each dimension of the input tensor.
+ strides: An int or list of `ints` that has length `1` or `3`. The stride of
+ the sliding window for each dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
+ the "returns" section of `tf.nn.convolution` for details.
+ data_format: An optional string from: "NWC", "NCW". Defaults to "NWC".
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` of format specified by `data_format`.
+ The max pooled output tensor.
+ """
+ with ops.name_scope(name, "MaxPool1d", [input]) as name:
+ if data_format is None:
+ data_format = "NWC"
+ channel_index = 1 if data_format.startswith("NC") else 2
+ ksize = [1] + _get_sequence(ksize, 1, channel_index, "ksize")
+ strides = [1] + _get_sequence(strides, 1, channel_index, "strides")
+
+ data_format = "NHWC" if data_format == "NWC" else "NCHW"
+ expanding_dim = 1 if data_format == "NWC" else 2
+
+ input = array_ops.expand_dims_v2(input, expanding_dim)
+ result = gen_nn_ops.max_pool(
+ input,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ name=name)
+ return array_ops.squeeze(result, expanding_dim)
+# pylint: enable=redefined-builtin
+
+
+# pylint: disable=redefined-builtin
+@tf_export("nn.max_pool2d")
+def max_pool2d(input, ksize, strides, padding, data_format="NHWC", name=None):
+ """Performs the max pooling on the input.
+
+ Args:
+ input: A 4-D `Tensor` of the format specified by `data_format`.
+ ksize: An int or list of `ints` that has length `1`, `2` or `4`. The size of
+ the window for each dimension of the input tensor.
+ strides: An int or list of `ints` that has length `1`, `2` or `4`. The
+ stride of the sliding window for each dimension of the input tensor.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. See
+ the "returns" section of `tf.nn.convolution` for details.
+ data_format: A string. 'NHWC', 'NCHW' and 'NCHW_VECT_C' are supported.
+ name: Optional name for the operation.
+
+ Returns:
+ A `Tensor` of format specified by `data_format`.
+ The max pooled output tensor.
+ """
+ with ops.name_scope(name, "MaxPool2d", [input]) as name:
+ if data_format is None:
+ data_format = "NHWC"
+ channel_index = 1 if data_format.startswith("NC") else 3
+
+ ksize = _get_sequence(ksize, 2, channel_index, "ksize")
+ strides = _get_sequence(strides, 2, channel_index, "strides")
+
+ return gen_nn_ops.max_pool(
+ input,
+ ksize=ksize,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ name=name)
+# pylint: enable=redefined-builtin
+
+
+# pylint: disable=redefined-builtin
@tf_export("nn.max_pool3d")
def max_pool3d(input, ksize, strides, padding, data_format="NDHWC", name=None):
"""Performs the max pooling on the input.
@@ -3138,8 +3305,6 @@
padding=padding,
data_format=data_format,
name=name)
-
-
# pylint: enable=redefined-builtin
@@ -3211,15 +3376,18 @@
strides,
padding,
data_format="NHWC",
- Targmax=dtypes.int64, # pylint: disable=invalid-name
+ Targmax=None, # pylint: disable=invalid-name
name=None,
output_dtype=None):
+ if data_format != "NHWC":
+ raise ValueError("Data formats other than 'NHWC' are not yet supported")
+
Targmax = deprecated_argument_lookup(
"output_dtype", output_dtype, "Targmax", Targmax)
- if output_dtype is not None:
- Targmax = output_dtype
+ if Targmax is None:
+ Targmax = dtypes.int64
return gen_nn_ops.max_pool_with_argmax(
- input, ksize, strides, padding, data_format, Targmax, name)
+ input, ksize, strides, padding, Targmax, name)
max_pool_with_argmax_v1.__doc__ = gen_nn_ops.max_pool_with_argmax.__doc__
# pylint: enable=redefined-builtin
@@ -3831,14 +3999,16 @@
"`NHWC` for data_format is deprecated, use `NWC` instead",
warn_once=True,
data_format="NHWC")
-def conv1d(value=None,
- filters=None,
- stride=None,
- padding=None,
- use_cudnn_on_gpu=None,
- data_format=None,
- name=None,
- input=None): # pylint: disable=redefined-builtin
+def conv1d(
+ value,
+ filters,
+ stride,
+ padding,
+ use_cudnn_on_gpu=None,
+ data_format=None,
+ name=None,
+ input=None, # pylint: disable=redefined-builtin
+ dilations=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape
@@ -3866,8 +4036,8 @@
Args:
value: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `value`.
- stride: An `integer`. The number of entries by which
- the filter is moved right at each step.
+ stride: An int or list of `ints` that has length `1` or `3`. The number of
+ entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
use_cudnn_on_gpu: An optional `bool`. Defaults to `True`.
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
@@ -3876,6 +4046,10 @@
data as [batch, in_channels, in_width].
name: A name for the operation (optional).
input: Alias for value.
+ dilations: An int or list of `ints` that has length `1` or `3` which
+ defaults to 1. The dilation factor for each dimension of input. If set to
+ k > 1, there will be k-1 skipped cells between each filter element on that
+ dimension. Dilations in the batch and depth dimensions must be 1.
Returns:
A `Tensor`. Has the same type as input.
@@ -3889,13 +4063,16 @@
if data_format is None or data_format == "NHWC" or data_format == "NWC":
data_format = "NHWC"
spatial_start_dim = 1
- strides = [1, 1, stride, 1]
+ channel_index = 2
elif data_format == "NCHW" or data_format == "NCW":
data_format = "NCHW"
spatial_start_dim = 2
- strides = [1, 1, 1, stride]
+ channel_index = 1
else:
raise ValueError("data_format must be \"NWC\" or \"NCW\".")
+ strides = [1] + _get_sequence(stride, 1, channel_index, "stride")
+ dilations = [1] + _get_sequence(dilations, 1, channel_index, "dilations")
+
value = array_ops.expand_dims(value, spatial_start_dim)
filters = array_ops.expand_dims(filters, 0)
result = gen_nn_ops.conv2d(
@@ -3904,17 +4081,21 @@
strides,
padding,
use_cudnn_on_gpu=use_cudnn_on_gpu,
- data_format=data_format)
+ data_format=data_format,
+ dilations=dilations,
+ name=name)
return array_ops.squeeze(result, [spatial_start_dim])
@tf_export("nn.conv1d", v1=[])
-def conv1d_v2(input, # pylint: disable=redefined-builtin
- filters,
- stride,
- padding,
- data_format=None,
- name=None):
+def conv1d_v2(
+ input, # pylint: disable=redefined-builtin
+ filters,
+ stride,
+ padding,
+ data_format="NWC",
+ dilations=None,
+ name=None):
r"""Computes a 1-D convolution given 3-D input and filter tensors.
Given an input tensor of shape
@@ -3942,13 +4123,17 @@
Args:
input: A 3D `Tensor`. Must be of type `float16`, `float32`, or `float64`.
filters: A 3D `Tensor`. Must have the same type as `input`.
- stride: An `integer`. The number of entries by which
- the filter is moved right at each step.
+ stride: An int or list of `ints` that has length `1` or `3`. The number of
+ entries by which the filter is moved right at each step.
padding: 'SAME' or 'VALID'
data_format: An optional `string` from `"NWC", "NCW"`. Defaults
to `"NWC"`, the data is stored in the order of
[batch, in_width, in_channels]. The `"NCW"` format stores
data as [batch, in_channels, in_width].
+ dilations: An int or list of `ints` that has length `1` or `3` which
+ defaults to 1. The dilation factor for each dimension of input. If set to
+ k > 1, there will be k-1 skipped cells between each filter element on that
+ dimension. Dilations in the batch and depth dimensions must be 1.
name: A name for the operation (optional).
Returns:
@@ -3957,13 +4142,15 @@
Raises:
ValueError: if `data_format` is invalid.
"""
- return conv1d(input, # pylint: disable=redefined-builtin
- filters,
- stride,
- padding,
- use_cudnn_on_gpu=True,
- data_format=data_format,
- name=name)
+ return conv1d(
+ input, # pylint: disable=redefined-builtin
+ filters,
+ stride,
+ padding,
+ use_cudnn_on_gpu=True,
+ data_format=data_format,
+ name=name,
+ dilations=dilations)
def conv1d_transpose(
@@ -3988,21 +4175,22 @@
filter: A 3-D `Tensor` with the same type as `value` and shape
`[filter_width, output_channels, in_channels]`. `filter`'s
`in_channels` dimension must match that of `value`.
- output_shape: A 1-D `Tensor` representing the output shape of the
- deconvolution op.
+ output_shape: A 1-D `Tensor`, containing three elements, representing the
+ output shape of the deconvolution op.
stride: An `integer`. The number of entries by which
the filter is moved right at each step.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
See the "returns" section of `tf.nn.convolution` for details.
- data_format: A string. 'NHWC' and 'NCHW' are supported.
+ data_format: A string. `'NWC'` and `'NCW'` are supported.
name: Optional name for the returned tensor.
Returns:
A `Tensor` with the same type as `value`.
Raises:
- ValueError: If input/output depth does not match `filter`'s shape, or if
- padding is other than `'VALID'` or `'SAME'`.
+ ValueError: If input/output depth does not match `filter`'s shape, if
+ `output_shape` is not at 3-element vector, if `padding` is other than
+ `'VALID'` or `'SAME'`, or if `data_format` is invalid.
"""
with ops.name_scope(name, "conv1d_transpose",
[value, filter, output_shape]) as name:
diff --git a/tensorflow/python/ops/nn_test.py b/tensorflow/python/ops/nn_test.py
index 7456134..3279dca 100644
--- a/tensorflow/python/ops/nn_test.py
+++ b/tensorflow/python/ops/nn_test.py
@@ -1241,5 +1241,81 @@
self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
+@test_util.run_all_in_graph_and_eager_modes
+class MaxPoolTest(test_lib.TestCase):
+
+ def test1DTensor(self):
+ x = array_ops.ones([3, 6, 5])
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def test1DNumpy(self):
+ x = np.ones([3, 6, 5])
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool1d(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def test2DTensor(self):
+ x = array_ops.ones([3, 6, 6, 5])
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def test2DNumpy(self):
+ x = np.ones([3, 6, 6, 5])
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def test3DTensor(self):
+ x = array_ops.ones([3, 7, 6, 6, 5])
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def test3DNumpy(self):
+ x = np.ones([3, 7, 6, 6, 5], dtype=np.float32)
+ ksize = 2
+ strides = 2
+
+ y1 = nn_ops.max_pool_v2(x, ksize, strides, "SAME")
+ y2 = nn_ops.max_pool3d(x, ksize, strides, "SAME")
+
+ self.assertAllEqual(self.evaluate(y1), self.evaluate(y2))
+
+ def testIncorrectSizeInputSmall(self):
+ x = array_ops.ones([3, 4])
+ with self.assertRaisesRegex(
+ ValueError, "Input tensor must be of rank 3, 4 or 5 but was 2."):
+ nn_ops.max_pool_v2(x, 2, 2, "SAME")
+
+ def testIncorrectSizeInput(self):
+ x = array_ops.ones([3, 4, 1, 2, 1, 2])
+ with self.assertRaisesRegex(
+ ValueError, "Input tensor must be of rank 3, 4 or 5 but was 6."):
+ nn_ops.max_pool_v2(x, 2, 2, "SAME")
+
+
if __name__ == "__main__":
test_lib.main()
diff --git a/tensorflow/python/ops/parallel_for/math_test.py b/tensorflow/python/ops/parallel_for/math_test.py
index db88f4f..7a5bef7 100644
--- a/tensorflow/python/ops/parallel_for/math_test.py
+++ b/tensorflow/python/ops/parallel_for/math_test.py
@@ -278,7 +278,7 @@
x = random_ops.random_uniform([2, 3, 4, 5])
for op in [
math_ops.reduce_sum, math_ops.reduce_prod, math_ops.reduce_max,
- math_ops.reduce_min
+ math_ops.reduce_min, math_ops.reduce_mean,
]:
for axis in ([1], None, [0, 2]):
for keepdims in (True, False):
@@ -325,26 +325,46 @@
self._test_loop_fn(loop_fn, 2)
def test_bias_add(self):
- x_shape = [2, 3, 4, 5, 6]
- x = random_ops.random_uniform(x_shape)
for data_format in ("NCHW", "NHWC"):
- with backprop.GradientTape(persistent=True) as g:
- bias_dim = 2 if data_format == "NCHW" else -1
- bias_shape = x_shape[bias_dim]
- bias = random_ops.random_uniform([bias_shape])
- g.watch(bias)
+ for stacked_value in (True, False):
+ x_shape = [3, 4, 5, 6]
+ if stacked_value:
+ x_shape = [2] + x_shape
+ x = random_ops.random_uniform(x_shape)
+ for stacked_bias in (True, False):
+ if not (stacked_value or stacked_bias):
+ continue
+ with backprop.GradientTape(persistent=True) as g:
+ bias_dim = -1
+ if data_format == "NCHW":
+ bias_dim = 2 if stacked_value else 1
+ bias_shape = [x_shape[bias_dim]]
+ if stacked_bias:
+ bias_shape = [2] + bias_shape
+ bias = random_ops.random_uniform(bias_shape)
+ g.watch(bias)
- # pylint: disable=cell-var-from-loop
- def loop_fn(i):
- with g:
- a = array_ops.gather(x, i)
- y = nn.bias_add(a, bias, data_format=data_format)
- loss = math_ops.reduce_sum(y * y)
- return y, g.gradient(loss, bias)
- # pylint: enable=cell-var-from-loop
+ # pylint: disable=cell-var-from-loop
+ def loop_fn(i):
+ with g:
+ a = array_ops.gather(x, i) if stacked_value else x
+ b = array_ops.gather(bias, i) if stacked_bias else bias
+ y = nn.bias_add(a, b, data_format=data_format)
+ loss = math_ops.reduce_sum(y * y)
+ grad = g.gradient(loss, bias)
+ if stacked_bias:
+ # If we gather over bias in loop_fn, the gradient will be an
+ # instance of `IndexedSlices` with attrs `values` and `indices`.
+ return y, grad.values, grad.indices
+ else:
+ return y, grad
+ # pylint: enable=cell-var-from-loop
- self._test_loop_fn(
- loop_fn, 2, loop_fn_dtypes=[dtypes.float32, dtypes.float32])
+ out_dtypes = [dtypes.float32, dtypes.float32]
+ if stacked_bias:
+ out_dtypes = out_dtypes + [dtypes.int32]
+ self._test_loop_fn(
+ loop_fn, 2, loop_fn_dtypes=out_dtypes)
def test_unsorted_segment_sum(self):
t = random_ops.random_uniform([3, 3, 2])
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index b0f6a6a..019d4f2 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -42,6 +42,7 @@
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import compat
from tensorflow.python.util import nest
flags.DEFINE_bool(
@@ -1876,6 +1877,7 @@
@RegisterPForWithArgs("Prod", math_ops.reduce_prod)
@RegisterPForWithArgs("Max", math_ops.reduce_max)
@RegisterPForWithArgs("Min", math_ops.reduce_min)
+@RegisterPForWithArgs("Mean", math_ops.reduce_mean)
def _convert_reduction(pfor_input, _, op_func):
t = pfor_input.stacked_input(0)
indices = pfor_input.unstacked_input(1)
@@ -1899,17 +1901,30 @@
@RegisterPFor("BiasAdd")
def _convert_biasadd(pfor_input):
- t = pfor_input.stacked_input(0)
- bias = pfor_input.unstacked_input(1)
+ t, t_stacked, _ = pfor_input.input(0)
+ bias, bias_stacked, _ = pfor_input.input(1)
data_format = pfor_input.get_attr("data_format")
- if data_format != b"NCHW":
+ if bias_stacked:
+ # BiasAdd only supports 1-D biases, so cast bias to match value and use Add.
+ pfor_input.expanddim_inputs_for_broadcast()
+ t, _, _ = pfor_input.input(0)
+ bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype)
+ if compat.as_bytes(data_format) == b"NCHW":
+ b_shape = array_ops.shape(bias)
+ new_b_shape = array_ops.concat(
+ [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0)
+ bias = array_ops.reshape(bias, new_b_shape)
+ return wrap(math_ops.add(t, bias), True)
+ else:
+ assert t_stacked, "At least one input to BiasAdd should be loop variant."
+ if compat.as_bytes(data_format) == b"NCHW":
+ shape = array_ops.shape(t)
+ flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
+ t = array_ops.reshape(t, flattened_shape)
+ t = nn_ops.bias_add(t, bias, data_format=b"NCHW")
+ t = array_ops.reshape(t, shape)
+ return wrap(t, True)
return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True)
- shape = array_ops.shape(t)
- flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0)
- t = array_ops.reshape(t, flattened_shape)
- t = nn_ops.bias_add(t, bias, data_format=b"NCHW")
- t = array_ops.reshape(t, shape)
- return wrap(t, True)
@RegisterPFor("UnsortedSegmentSum")
diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD
index e3bdb74..4f29fcc 100644
--- a/tensorflow/python/ops/ragged/BUILD
+++ b/tensorflow/python/ops/ragged/BUILD
@@ -25,6 +25,7 @@
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
+ ":ragged_batch_gather_with_default_op",
":ragged_concat_ops",
":ragged_conversion_ops",
":ragged_dispatch",
@@ -91,6 +92,29 @@
)
py_library(
+ name = "ragged_batch_gather_with_default_op",
+ srcs = [
+ "ragged_batch_gather_with_default_op.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ragged_array_ops",
+ ":ragged_batch_gather_ops",
+ ":ragged_concat_ops",
+ ":ragged_dispatch",
+ ":ragged_operators",
+ ":ragged_tensor",
+ ":ragged_tensor_shape",
+ ":ragged_where_op",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ ],
+)
+
+py_library(
name = "ragged_concat_ops",
srcs = ["ragged_concat_ops.py"],
srcs_version = "PY2AND3",
@@ -365,11 +389,13 @@
srcs_version = "PY2AND3",
deps = [
":ragged_array_ops",
+ ":ragged_batch_gather_ops",
":ragged_math_ops",
":ragged_tensor",
":ragged_tensor_shape",
":ragged_util",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:bitwise_ops",
"//tensorflow/python:clip_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
@@ -509,6 +535,7 @@
deps = [
":ragged_array_ops",
":ragged_batch_gather_ops",
+ ":ragged_batch_gather_with_default_op",
":ragged_factory_ops",
":ragged_tensor",
":ragged_test_util",
diff --git a/tensorflow/python/ops/ragged/__init__.py b/tensorflow/python/ops/ragged/__init__.py
index a5ffd8a..e9232a1 100644
--- a/tensorflow/python/ops/ragged/__init__.py
+++ b/tensorflow/python/ops/ragged/__init__.py
@@ -30,6 +30,7 @@
from tensorflow.python.ops.ragged import ragged_array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
+from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_concat_ops
from tensorflow.python.ops.ragged import ragged_conversion_ops
from tensorflow.python.ops.ragged import ragged_dispatch
diff --git a/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py b/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py
index 72692fc..17c55eb 100644
--- a/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py
+++ b/tensorflow/python/ops/ragged/ragged_batch_gather_op_test.py
@@ -21,11 +21,13 @@
from absl.testing import parameterized
from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged import ragged_batch_gather_ops
+from tensorflow.python.ops.ragged import ragged_batch_gather_with_default_op
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_test_util
@@ -149,6 +151,324 @@
result = ragged_batch_gather_ops.batch_gather(params, indices)
self.assertRaggedEqual(result, expected)
+ @parameterized.parameters([
+ # Docstring example:
+ dict(
+ descr='Docstring example',
+ params=[['a', 'b', 'c'], ['d'], [], ['e']],
+ indices=[[1, 2, -1], [], [], [0, 10]],
+ expected=[['b', 'c', 'FOO'], [], [], ['e', 'FOO']],
+ default_value='FOO',
+ ),
+ # Dimensions:
+ # indices: [4]
+ # params: [2, (d1), (d2)]
+ dict(
+ descr='params: [2, (d1), (d2), indices: [4]',
+ indices=[1, 100, 0, -1],
+ params=[[['The', 'deal', 'came', 'about', '18', 'months', 'after',
+ 'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
+ 'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall']],
+ [["It's", 'always', 'darkest', 'before', 'the', 'dawn']]],
+ expected=[[["It's", 'always', 'darkest', 'before', 'the', 'dawn']],
+ [['$NONE^']],
+ [['The', 'deal', 'came', 'about', '18', 'months', 'after',
+ 'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion',
+ '-', 'dollar', 'takeover', 'offer', 'from', 'Microsoft',
+ '.'],
+ ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall']],
+ [['$NONE^']]],
+ ),
+ # Dimensions:
+ # params: [1, (d1)]
+ # indices: [3]
+ dict(
+ descr='params: rank 2, indices: rank 1',
+ params=[
+ ['Bruce', 'Wayne'],
+ ],
+ indices=[-1, 0, 1000],
+ expected=[['$NONE^'], ['Bruce', 'Wayne'], ['$NONE^']]
+ ),
+ # Dimensions:
+ # params: [1, (d1)]
+ # indices: [1, (d2)]
+ dict(
+ descr='Test underbound indices of shape [1, (d2)]',
+ params=[
+ ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
+ '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
+ 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ],
+ indices=[[8, -1]],
+ expected=[['!', '$NONE^']],
+ ),
+ dict(
+ descr='Test underbound indices of shape [2, (d2)]',
+ params=[
+ ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
+ '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
+ 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['Who', 'let', 'the', 'dogs', 'out', '?'],
+ ],
+ indices=[[8, -1], [1, 100]],
+ expected=[['!', '$NONE^'], ['let', '$NONE^']],
+ ),
+ # Dimensions:
+ # params: [2, (d1)]
+ # indices: [2, (d2)]
+ dict(
+ descr='Test underbound indices of rank 2',
+ params=[
+ ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
+ '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
+ 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['He', 'left', 'us', '.', 'Little', 'boys', 'crowded', 'together',
+ 'on', 'long', 'wooden', 'benches', ',', 'and', 'in', 'the',
+ 'center', 'of', 'the', 'room', 'sat', 'the', 'teacher', '.',
+ 'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
+ 'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
+ 'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
+ 'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
+ 'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',', 'then',
+ 'shouted', 'in', 'Yiddish', ',', '``', 'One', ',', 'two', ',',
+ 'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
+ 'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
+ 'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
+ 'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
+ 'had', 'previously', 'chanted', 'in', 'Hebrew', '.']],
+ indices=[[8, -1], [3, 23, 35, 45, 75, 83, -121]],
+ expected=[['!', '$NONE^'], ['.', '.', '.', '.', '!', '.', '$NONE^']],
+ ),
+ dict(
+ descr='Test overbound indices of rank 2',
+ params=[
+ ['The', 'deal', 'came', 'about', '18', 'months', 'after', 'Yahoo',
+ '!', 'rejected', 'a', '47.5', '-', 'billion', '-', 'dollar',
+ 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['He', 'left', 'us', '.', 'Little', 'boys', 'crowded', 'together',
+ 'on', 'long', 'wooden', 'benches', ',', 'and', 'in', 'the',
+ 'center', 'of', 'the', 'room', 'sat', 'the', 'teacher', '.',
+ 'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
+ 'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
+ 'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
+ 'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
+ 'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',', 'then',
+ 'shouted', 'in', 'Yiddish', ',', '``', 'One', ',', 'two', ',',
+ 'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
+ 'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
+ 'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
+ 'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
+ 'had', 'previously', 'chanted', 'in', 'Hebrew', '.']],
+ indices=[[8, 8823], [3, 23, 35, 45, 75, 83, 1234]],
+ expected=[['!', '$NONE^'], ['.', '.', '.', '.', '!', '.', '$NONE^']],
+ ),
+ # Dimensions:
+ # params: [2, (d1), 2]
+ # indices: [2, (d2)]
+ dict(
+ descr='params: rank 3, indices: rank 2',
+ params=[
+ [['The', 'deal'], ['takeover', 'offer'], ['from', 'Microsoft']],
+ [['Who', 'let'], ['the', 'dogs'], ['out', '?']],
+ ],
+ ragged_rank=1,
+ indices=[[1, -1, 2, 30], [1, 100]],
+ indices_ragged_rank=1,
+ expected=[[['takeover', 'offer'],
+ ['$NONE^', '$NONE^'],
+ ['from', 'Microsoft'],
+ ['$NONE^', '$NONE^']],
+ [['the', 'dogs'],
+ ['$NONE^', '$NONE^']]],
+ expected_ragged_rank=1,
+ default_value=['$NONE^', '$NONE^'],
+ ),
+ # Dimensions:
+ # params: [2, (d1), (d2)]
+ # indices: [2, (d3)]
+ dict(
+ descr='params: [2, (d1), (d2)], indices: [2, (d3)]',
+ params=[
+ [['The', 'deal', 'came', 'about', '18', 'months', 'after',
+ 'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
+ 'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall'],
+ ],
+ [['It\'s', 'always', 'darkest', 'before', 'the', 'dawn']]
+ ],
+ indices=[[1, 100], [0, -1]],
+ expected=[[['Trumpty', 'Dumpty', 'sat', 'on', 'a', 'wall'],
+ ['$NONE^']],
+ [["It's", 'always', 'darkest', 'before', 'the', 'dawn'],
+ ['$NONE^']]]
+ ),
+ # Dimensions:
+ # params: [2, (d1), (d2)]
+ # indices: [2, (d1), (d3)]
+ dict(
+ descr='Test overbound indices of rank 3',
+ params=[
+ [['The', 'deal', 'came', 'about', '18', 'months', 'after',
+ 'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
+ 'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.'],
+ ['Foo', 'bar', 'mar']],
+ [['He', 'left', 'us', '.', 'Little', 'boys', 'crowded',
+ 'together', 'on', 'long', 'wooden', 'benches', ',', 'and', 'in',
+ 'the', 'center', 'of', 'the', 'room', 'sat', 'the', 'teacher',
+ '.', 'His', 'black', 'beard', 'dripped', 'down', 'over', 'the',
+ 'front', 'of', 'his', 'coat', '.', 'One', 'white', 'hand',
+ 'poised', 'a', 'stick', 'above', 'his', 'desk', '.', 'He',
+ 'turned', 'his', 'surly', ',', 'half', '-', 'closed', 'eyes',
+ 'toward', 'us', ',', 'stared', 'for', 'a', 'second', ',',
+ 'then', 'shouted', 'in', 'Yiddish', ',', '``', 'One', ',',
+ 'two', ',',
+ 'three', "''", '!', '!', 'Rapping', 'the', 'stick', 'against',
+ 'the', 'desk', '.', 'The', 'little', 'boys', 'shrilled', 'out',
+ 'a', 'Yiddish', 'translation', 'or', 'interpretation', 'of',
+ 'the', 'Five', 'Books', 'of', 'Moses', ',', 'which', 'they',
+ 'had', 'previously', 'chanted', 'in', 'Hebrew', '.'],
+ ['I', 'too', 'was', 'hustled', 'scammed', 'bamboozled', 'hood',
+ 'winked', 'lead', 'astray']]
+ ],
+ indices=[[[8, 8823], [0, 100]], [[3, 23, 35, 45, 75, 83, 1234], [5]]],
+ expected=[[['!', '$NONE^'], ['Foo', '$NONE^']],
+ [['.', '.', '.', '.', '!', '.', '$NONE^'],
+ ['bamboozled']]],
+ ),
+ # params.shape = [2, (d1), 8]
+ # indices.shape = [2, (d1), 3]
+ dict(
+ descr='params = [2, (2, 1), 8], indices = [2, (2, 1), 3]',
+ params=[[['h'] * 8, ['w'] * 8], [['b'] * 8]],
+ ragged_rank=1,
+ indices=[[[0, 100, 1], [0, 1, 0]], [[1, 0, 0]]],
+ indices_ragged_rank=1,
+ expected=[[['h', '$NONE^', 'h'], ['w', 'w', 'w']], [['b', 'b', 'b']]],
+ expected_ragged_rank=1,
+ ),
+ ])
+ def testRaggedBatchGatherWithDefault(
+ self, descr, params, indices, expected, indices_ragged_rank=None,
+ expected_ragged_rank=None, ragged_rank=None, default_value='$NONE^'):
+ params = ragged_factory_ops.constant(params, ragged_rank=ragged_rank)
+ indices = ragged_factory_ops.constant(
+ indices, ragged_rank=indices_ragged_rank or ragged_rank)
+ expected = ragged_factory_ops.constant(
+ expected, ragged_rank=expected_ragged_rank or ragged_rank)
+ result = ragged_batch_gather_with_default_op.batch_gather_with_default(
+ params, indices, default_value)
+ self.assertRaggedEqual(result, expected)
+
+ @parameterized.parameters([
+ # Dimensions:
+ # params: dims [2, 5], indices: [2, 2]
+ dict(
+ descr='params: dims [2, 5], indices: [2, 2]',
+ params=[
+ ['The', 'deal', 'came', 'about', '18'],
+ ['He', 'left', 'us', '.', 'Little']],
+ indices=[[0, -1], [3, 121]],
+ expected=[['The', '$NONE^'], ['.', '$NONE^']],
+ default_value='$NONE^',
+ ),
+ # Dimensions:
+ # params: dims [2, 2, 5], indices: [2, 2]
+ dict(
+ descr='params: dims [2, 2, 5], indices: [2, 2]',
+ params=[
+ [['The', 'deal', 'came', 'about', '18'],
+ ['The', 'deal', 'came', 'about', '19'],
+ ],
+ [['He', 'left', 'us', '.', 'Little'],
+ ['The', 'deal', 'came', 'about', '20'],
+ ]
+ ],
+ indices=[[0, -1], [0, 121]],
+ expected=[[['The', 'deal', 'came', 'about', '18'],
+ ['$NONE^', '$NONE^', '$NONE^', '$NONE^', '$NONE^']],
+ [['He', 'left', 'us', '.', 'Little'],
+ ['$NONE^', '$NONE^', '$NONE^', '$NONE^', '$NONE^']]],
+ default_value='$NONE^',
+ ),
+ # Test default_value with shape [5]
+ dict(
+ descr='params: dims [2, 2, 5], indices: [2, 2]',
+ params=[
+ [['The', 'deal', 'came', 'about', '18'],
+ ['The', 'deal', 'came', 'about', '19'],
+ ],
+ [['He', 'left', 'us', '.', 'Little'],
+ ['The', 'deal', 'came', 'about', '20'],
+ ]
+ ],
+ indices=[[0, -1], [0, 121]],
+ expected=[[['The', 'deal', 'came', 'about', '18'],
+ [':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:']],
+ [['He', 'left', 'us', '.', 'Little'],
+ [':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:']]],
+ default_value=[':FOO:', ':FOO:', ':FOO:', ':FOO:', ':FOO:'],
+ ),
+ ])
+ def testRaggedBatchGatherWithDefaultOnTensors(
+ self, descr, params, indices, expected, default_value):
+ params = constant_op.constant(params)
+ indices = constant_op.constant(indices)
+ expected = constant_op.constant(expected)
+ result = ragged_batch_gather_with_default_op.batch_gather_with_default(
+ params, indices, default_value)
+ self.assertAllEqual(expected, result)
+
+ @parameterized.parameters([
+ dict(
+ params=[['The', 'deal', 'came', 'about', '18', 'months', 'after',
+ 'Yahoo', '!', 'rejected', 'a', '47.5', '-', 'billion', '-',
+ 'dollar', 'takeover', 'offer', 'from', 'Microsoft', '.']],
+ indices=[[[8, -1]]],
+ # Exception here because different errors are thrown in eager vs
+ # graph mode.
+ error=Exception,
+ default_value='$NONE^',
+ ),
+ ])
+ def testRankMismatch(
+ self, params, indices, default_value, error):
+ params = ragged_factory_ops.constant(params)
+ indices = ragged_factory_ops.constant(indices)
+ with self.assertRaises(error):
+ _ = ragged_batch_gather_with_default_op.batch_gather_with_default(
+ params, indices, default_value)
+
+ @parameterized.parameters([
+ # Dimensions:
+ # params: [2, (d1), 2]
+ # indices: [2, (d2)]
+ # default_value: []
+ dict(
+ descr='params: rank 3, indices: rank 2, default: rank = [], but'
+ ' should be [2]',
+ params=[
+ [['The', 'deal'], ['takeover', 'offer'], ['from', 'Microsoft']],
+ [['Who', 'let'], ['the', 'dogs'], ['out', '?']],
+ ],
+ ragged_rank=1,
+ indices=[[1, -1, 2, 30], [1, 100]],
+ indices_ragged_rank=1,
+ default_value='$NONE^',
+ error=Exception,
+ )
+ ])
+ def testInvalidDefaultValueRank(
+ self, descr, params, indices, default_value, error, ragged_rank=None,
+ indices_ragged_rank=None):
+ params = ragged_factory_ops.constant(params, ragged_rank=ragged_rank)
+ indices = ragged_factory_ops.constant(
+ indices, ragged_rank=indices_ragged_rank)
+ with self.assertRaises(error):
+ _ = ragged_batch_gather_with_default_op.batch_gather_with_default(
+ params, indices, default_value)
+
def testRaggedBatchGatherUnknownRankError(self):
if context.executing_eagerly():
return
diff --git a/tensorflow/python/ops/ragged/ragged_batch_gather_with_default_op.py b/tensorflow/python/ops/ragged/ragged_batch_gather_with_default_op.py
new file mode 100644
index 0000000..0d99540
--- /dev/null
+++ b/tensorflow/python/ops/ragged/ragged_batch_gather_with_default_op.py
@@ -0,0 +1,186 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Array operations for RaggedTensors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.ragged import ragged_array_ops
+from tensorflow.python.ops.ragged import ragged_dispatch # pylint: disable=unused-import
+from tensorflow.python.ops.ragged import ragged_operators # pylint: disable=unused-import
+from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import ragged_tensor_shape
+from tensorflow.python.ops.ragged import ragged_where_op
+
+
+#===============================================================================
+# ragged.batch_gather_with_default
+#===============================================================================
+def batch_gather_with_default(params,
+ indices,
+ default_value='',
+ name=None):
+ """Same as `batch_gather` but inserts `default_value` for invalid indices.
+
+ This operation is similar to `batch_gather` except that it will substitute
+ the value for invalid indices with `default_value` as the contents.
+ See `batch_gather` for more details.
+
+
+ Args:
+ params: A potentially ragged tensor with shape `[B1...BN, P1...PM]` (`N>=0`,
+ `M>0`).
+ indices: A potentially ragged tensor with shape `[B1...BN, I]` (`N>=0`).
+ default_value: A value to be inserted in places where `indices` are out of
+ bounds. Must be the same dtype as params and either a scalar or rank 1.
+ name: A name for the operation (optional).
+
+ Returns:
+ A potentially ragged tensor with shape `[B1...BN, I, P2...PM]`.
+ `result.ragged_rank = max(indices.ragged_rank, params.ragged_rank)`.
+
+ #### Example:
+ ```python
+ >>> params = tf.ragged.constant([
+ ['a', 'b', 'c'],
+ ['d'],
+ [],
+ ['e']])
+ >>> indices = tf.ragged.constant([[1, 2, -1], [], [], [0, 10]])
+ >>> batch_gather_with_default(params, indices, 'FOO')
+ [['b', 'c', 'FOO'], [], [], ['e', 'FOO']]
+ ```
+ """
+ with ops.name_scope(name, 'RaggedBatchGatherWithDefault'):
+ params = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+ params, name='params',
+ )
+ indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+ indices, name='indices',
+ )
+ default_value = ragged_tensor.convert_to_tensor_or_ragged_tensor(
+ default_value, name='default_value',
+ )
+ # TODO(hterry): lift this restriction and support default_values of
+ # of rank > 1
+ if (default_value.shape.ndims is not 0
+ and default_value.shape.ndims is not 1):
+ raise ValueError('"default_value" must be a scalar or vector')
+ upper_bounds = None
+ if indices.shape.ndims is None:
+ raise ValueError('Indices must have a known rank.')
+ if params.shape.ndims is None:
+ raise ValueError('Params must have a known rank.')
+
+ num_batch_dimensions = indices.shape.ndims - 1
+ pad = None
+ # The logic for this works as follows:
+ # - create a padded params, where:
+ # padded_params[b1...bn, 0] = default_value
+ # padded_params[b1...bn, i] = params[b1...bn, i-1] (i>0)
+ # - create an `upper_bounds` Tensor that contains the number of elements
+ # in each innermost rank. Broadcast `upper_bounds` to be the same shape
+ # as `indices`.
+ # - check to see which index in `indices` are out of bounds and substitute
+ # it with the index containing `default_value` (the first).
+ # - call batch_gather with the indices adjusted.
+ with ops.control_dependencies([
+ check_ops.assert_greater_equal(array_ops.rank(params),
+ array_ops.rank(indices))]):
+ if ragged_tensor.is_ragged(params):
+ row_lengths = ragged_array_ops.expand_dims(
+ params.row_lengths(axis=num_batch_dimensions),
+ axis=-1)
+ upper_bounds = math_ops.cast(row_lengths, indices.dtype)
+
+ pad_shape = _get_pad_shape(params, indices)
+
+ pad = ragged_tensor_shape.broadcast_to(
+ default_value, pad_shape)
+ else:
+ params_shape = array_ops.shape(params)
+ pad_shape = array_ops.concat([
+ params_shape[:num_batch_dimensions],
+ [1],
+ params_shape[num_batch_dimensions + 1:params.shape.ndims]
+ ], 0)
+ upper_bounds = params_shape[num_batch_dimensions]
+ pad = array_ops.broadcast_to(default_value, pad_shape)
+
+ # Add `default_value` as the first value in the innermost (ragged) rank.
+ pad = math_ops.cast(pad, params.dtype)
+ padded_params = array_ops.concat(
+ [pad, params], axis=num_batch_dimensions)
+
+ # Adjust the indices by substituting out-of-bound indices to the
+ # default-value index (which is the first element)
+ shifted_indices = indices + 1
+ is_out_of_bounds = (indices < 0) | (indices > upper_bounds)
+ adjusted_indices = ragged_where_op.where(
+ is_out_of_bounds,
+ x=array_ops.zeros_like(indices), y=shifted_indices,
+ )
+ return array_ops.batch_gather(
+ params=padded_params, indices=adjusted_indices, name=name)
+
+
+def _get_pad_shape(params, indices):
+ """Gets the RaggedTensorDynamicShape for the pad tensor."""
+ num_batch_dimensions = indices.shape.ndims - 1
+ params_shape = ragged_tensor_shape.RaggedTensorDynamicShape.from_tensor(
+ params)
+
+ # We want to create a pad tensor that can be concatenated with the params.
+ if params.shape.ndims == indices.shape.ndims:
+ # When params and indices are the same rank, the shape of the pad tensor is
+ # almost identical to params, except the last dimension which has size = 1.
+ if params_shape.num_inner_dimensions is 0:
+ pad_dims = params_shape.partitioned_dim_sizes[:-1] + (
+ array_ops.ones_like(params_shape.partitioned_dim_sizes[-1]),)
+ return ragged_tensor_shape.RaggedTensorDynamicShape(
+ pad_dims, [])
+ else:
+ return ragged_tensor_shape.RaggedTensorDynamicShape(
+ params_shape.partitioned_dim_sizes,
+ array_ops.concat([params_shape.inner_dim_sizes[:-1], [1]], axis=0))
+ else:
+ # When the rank of indices < params, the pad has the same dimension as
+ # params up to the 'num_batch_dimensions' rank. Every dimension after that
+ # has size 1.
+ pad_dims = None
+ if num_batch_dimensions == 0:
+ pad_dims = (constant_op.constant(1, dtype=dtypes.int64),) + (
+ constant_op.constant([1], dtype=dtypes.int64),) * (
+ params_shape.num_partitioned_dimensions -
+ num_batch_dimensions - 1)
+ else:
+ batch_dimensions = params_shape.partitioned_dim_sizes[
+ :num_batch_dimensions]
+ gather_dimension = params_shape.partitioned_dim_sizes[
+ num_batch_dimensions]
+ pad_dims = batch_dimensions + (
+ array_ops.ones_like(gather_dimension),) * (
+ params_shape.num_partitioned_dimensions - num_batch_dimensions)
+
+ return ragged_tensor_shape.RaggedTensorDynamicShape(
+ pad_dims, params_shape.inner_dim_sizes)
diff --git a/tensorflow/python/ops/ragged/segment_id_ops.py b/tensorflow/python/ops/ragged/segment_id_ops.py
index 42dc132..31e26e7 100644
--- a/tensorflow/python/ops/ragged/segment_id_ops.py
+++ b/tensorflow/python/ops/ragged/segment_id_ops.py
@@ -29,7 +29,7 @@
# For background on "segments" and "segment ids", see:
-# https://www.tensorflow.org/api_guides/python/math_ops#Segmentation
+# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.row_splits_to_segment_ids")
def row_splits_to_segment_ids(splits, name=None):
"""Generates the segmentation corresponding to a RaggedTensor `row_splits`.
@@ -64,7 +64,7 @@
# For background on "segments" and "segment ids", see:
-# https://www.tensorflow.org/api_guides/python/math_ops#Segmentation
+# https://www.tensorflow.org/api_docs/python/tf/math#Segmentation
@tf_export("ragged.segment_ids_to_row_splits")
def segment_ids_to_row_splits(segment_ids, num_segments=None, name=None):
"""Generates the RaggedTensor `row_splits` corresponding to a segmentation.
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 40c3771..603baea 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -1085,8 +1085,107 @@
return True
+class _RNNCellWrapperV1(RNNCell):
+ """Base class for cells wrappers V1 compatibility.
+
+ This class along with `_RNNCellWrapperV2` allows to define cells wrappers that
+ are compatible with V1 and V2, and defines helper methods for this purpose.
+ """
+
+ def __init__(self, cell):
+ super(_RNNCellWrapperV1, self).__init__()
+ self._cell = cell
+ if isinstance(cell, checkpointable.Checkpointable):
+ self._track_checkpointable(self._cell, name="cell")
+
+ def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
+ """Calls the wrapped cell and performs the wrapping logic.
+
+ This method is called from the wrapper's `call` or `__call__` methods.
+
+ Args:
+ inputs: A tensor with wrapped cell's input.
+ state: A tensor or tuple of tensors with wrapped cell's state.
+ cell_call_fn: Wrapped cell's method to use for step computation (cell's
+ `__call__` or 'call' method).
+ **kwargs: Additional arguments.
+
+ Returns:
+ A pair containing:
+ - Output: A tensor with cell's output.
+ - New state: A tensor or tuple of tensors with new wrapped cell's state.
+ """
+ raise NotImplementedError
+
+ def __call__(self, inputs, state, scope=None):
+ """Runs the RNN cell step computation.
+
+ We assume that the wrapped RNNCell is being built within its `__call__`
+ method. We directly use the wrapped cell's `__call__` in the overridden
+ wrapper `__call__` method.
+
+ This allows to use the wrapped cell and the non-wrapped cell equivalently
+ when using `__call__`.
+
+ Args:
+ inputs: A tensor with wrapped cell's input.
+ state: A tensor or tuple of tensors with wrapped cell's state.
+ scope: VariableScope for the subgraph created in the wrapped cells'
+ `__call__`.
+
+ Returns:
+ A pair containing:
+
+ - Output: A tensor with cell's output.
+ - New state: A tensor or tuple of tensors with new wrapped cell's state.
+ """
+ return self._call_wrapped_cell(
+ inputs, state, cell_call_fn=self._cell.__call__, scope=scope)
+
+
+class _RNNCellWrapperV2(LayerRNNCell, _RNNCellWrapperV1):
+ """Base class for cells wrappers V2 compatibility.
+
+ This class along with `_RNNCellWrapperV1` allows to define cells wrappers that
+ are compatible with V1 and V2, and defines helper methods for this purpose.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super(_RNNCellWrapperV2, self).__init__(*args, **kwargs)
+ self._layers = [self._cell]
+
+ def call(self, inputs, state, **kwargs):
+ """Runs the RNN cell step computation.
+
+ When `call` is being used, we assume that the wrapper object has been built,
+ and therefore the wrapped cells has been built via its `build` method and
+ its `call` method can be used directly.
+
+ This allows to use the wrapped cell and the non-wrapped cell equivalently
+ when using `call` and `build`.
+
+ Args:
+ inputs: A tensor with wrapped cell's input.
+ state: A tensor or tuple of tensors with wrapped cell's state.
+ **kwargs: Additional arguments passed to the wrapped cell's `call`.
+
+ Returns:
+ A pair containing:
+
+ - Output: A tensor with cell's output.
+ - New state: A tensor or tuple of tensors with new wrapped cell's state.
+ """
+ return self._call_wrapped_cell(
+ inputs, state, cell_call_fn=self._cell.call, **kwargs)
+
+ def build(self, inputs_shape):
+ """Builds the wrapped cell."""
+ self._cell.build(inputs_shape)
+ self.built = True
+
+
@tf_export(v1=["nn.rnn_cell.DropoutWrapper"])
-class DropoutWrapper(RNNCell):
+class DropoutWrapper(_RNNCellWrapperV1):
"""Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
@@ -1156,7 +1255,7 @@
but not `callable`.
ValueError: if any of the keep_probs are not between 0 and 1.
"""
- super(DropoutWrapper, self).__init__()
+ super(DropoutWrapper, self).__init__(cell)
assert_like_rnncell("cell", cell)
if (dropout_state_filter_visitor is not None
@@ -1181,10 +1280,7 @@
else:
setattr(self, "_%s" % attr, tensor_prob)
- # Set cell, variational_recurrent, seed before running the code below
- self._cell = cell
- if isinstance(cell, checkpointable.Checkpointable):
- self._track_checkpointable(self._cell, name="cell")
+ # Set variational_recurrent, seed before running the code below
self._variational_recurrent = variational_recurrent
self._seed = seed
@@ -1291,16 +1387,13 @@
shallow_filtered_substructure, dropout,
*[shallow_filtered_substructure, values, recurrent_noise])
- def _call(self, inputs, state, call_fn, **kwargs):
- """Defines a helper method that runs the wrapped cell and applies dropout.
-
- This helper is called from the DropoutWrapper's `call` or `__call__`
- methods.
+ def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
+ """Runs the wrapped cell and applies dropout.
Args:
inputs: A tensor with wrapped cell's input.
state: A tensor or tuple of tensors with wrapped cell's state.
- call_fn: Wrapped cell's method to use for step computation (cell's
+ cell_call_fn: Wrapped cell's method to use for step computation (cell's
`__call__` or 'call' method).
**kwargs: Additional arguments.
@@ -1317,7 +1410,7 @@
inputs = self._dropout(inputs, "input",
self._recurrent_input_noise,
self._input_keep_prob)
- output, new_state = call_fn(inputs, state, **kwargs)
+ output, new_state = cell_call_fn(inputs, state, **kwargs)
if _should_dropout(self._state_keep_prob):
# Identify which subsets of the state to perform dropout on and
# which ones to keep.
@@ -1333,40 +1426,81 @@
self._output_keep_prob)
return output, new_state
- def __call__(self, inputs, state, scope=None):
- """Runs the cell with the declared dropouts.
-
- We assume that the wrapped RNNCell is being built within its `__call__`
- method. We directly use the wrapped cell's `__call__` in the overridden
- DropoutWrapper `__call__` method.
-
- This should allow to use the wrapped cell and the non-wrapped cell
- equivalently when using `__call__`.
-
- Args:
- inputs: A tensor with wrapped cell's input.
- state: A tensor or tuple of tensors with wrapped cell's state.
- scope: VariableScope for the subgraph created in the wrapped cells'
- `__call__`.
-
- Returns:
- A pair containing:
-
- - Output: A tensor with cell's output.
- - New state: A tensor or tuple of tensors with new wrapped cell's state.
- """
- return self._call(inputs, state, call_fn=self._cell.__call__, scope=scope)
-
@tf_export("rnn.DropoutWrapper", v1=[])
-class DropoutWrapperV2(LayerRNNCell, DropoutWrapper):
+class DropoutWrapperV2(_RNNCellWrapperV2, DropoutWrapper):
"""Operator adding dropout to inputs and outputs of the given cell."""
def __init__(self, cell, input_keep_prob=1.0, output_keep_prob=1.0,
state_keep_prob=1.0, variational_recurrent=False,
input_size=None, dtype=None, seed=None,
dropout_state_filter_visitor=None):
- """Runs init in Keras style scope to use Keras-style variable management."""
+ """Create a cell with added input, state, and/or output dropout.
+
+ If `variational_recurrent` is set to `True` (**NOT** the default behavior),
+ then the same dropout mask is applied at every step, as described in:
+
+ Y. Gal, Z Ghahramani. "A Theoretically Grounded Application of Dropout in
+ Recurrent Neural Networks". https://arxiv.org/abs/1512.05287
+
+ Otherwise a different dropout mask is applied at every time step.
+
+ Note, by default (unless a custom `dropout_state_filter` is provided),
+ the memory state (`c` component of any `LSTMStateTuple`) passing through
+ a `DropoutWrapper` is never modified. This behavior is described in the
+ above article.
+
+ Runs initialization in Keras style scope to use Keras-style variable
+ management.
+
+ Args:
+ cell: a LayerRNNCell, a projection to output_size is added to it.
+ input_keep_prob: unit Tensor or float between 0 and 1, input keep
+ probability; if it is constant and 1, no input dropout will be added.
+ output_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is constant and 1, no output dropout will be added.
+ state_keep_prob: unit Tensor or float between 0 and 1, output keep
+ probability; if it is constant and 1, no output dropout will be added.
+ State dropout is performed on the outgoing states of the cell.
+ **Note** the state components to which dropout is applied when
+ `state_keep_prob` is in `(0, 1)` are also determined by
+ the argument `dropout_state_filter_visitor` (e.g. by default dropout
+ is never applied to the `c` component of an `LSTMStateTuple`).
+ variational_recurrent: Python bool. If `True`, then the same
+ dropout pattern is applied across all time steps per run call.
+ If this parameter is set, `input_size` **must** be provided.
+ input_size: (optional) (possibly nested tuple of) `TensorShape` objects
+ containing the depth(s) of the input tensors expected to be passed in to
+ the `DropoutWrapper`. Required and used **iff**
+ `variational_recurrent = True` and `input_keep_prob < 1`.
+ dtype: (optional) The `dtype` of the input, state, and output tensors.
+ Required and used **iff** `variational_recurrent = True`.
+ seed: (optional) integer, the randomness seed.
+ dropout_state_filter_visitor: (optional), default: (see below). Function
+ that takes any hierarchical level of the state and returns
+ a scalar or depth=1 structure of Python booleans describing
+ which terms in the state should be dropped out. In addition, if the
+ function returns `True`, dropout is applied across this sublevel. If
+ the function returns `False`, dropout is not applied across this entire
+ sublevel.
+ Default behavior: perform dropout on all terms except the memory (`c`)
+ state of `LSTMCellState` objects, and don't try to apply dropout to
+ `TensorArray` objects:
+ ```
+ def dropout_state_filter_visitor(s):
+ if isinstance(s, LSTMCellState):
+ # Never perform dropout on the c state.
+ return LSTMCellState(c=False, h=True)
+ elif isinstance(s, TensorArray):
+ return False
+ return True
+ ```
+
+ Raises:
+ TypeError: if `cell` is not an `RNNCell`, or `keep_state_fn` is provided
+ but not `callable`.
+ ValueError: if any of the keep_probs are not between 0 and 1.
+ """
with base_layer.keras_style_scope():
super(DropoutWrapperV2, self).__init__(
@@ -1380,36 +1514,9 @@
seed=seed,
dropout_state_filter_visitor=dropout_state_filter_visitor)
- def build(self, inputs_shape):
- self._cell.build(inputs_shape)
- self.built = True
- def call(self, inputs, state, **kwargs):
- """Runs the cell with the declared dropouts.
-
- When `call` is being used, we assume that the DropoutWrapper object has
- been built and therefore the wrapped cells has been built via its `build`
- method and its `call` method can be used directly.
-
- This should allow to use the wrapped cell and the non-wrapped cell
- equivalently when using `call` and `build`.
-
- Args:
- inputs: A tensor with wrapped cell's input.
- state: A tensor or tuple of tensors with wrapped cell's state.
- **kwargs: Additional arguments passed to the wrapped cell's `call`.
-
- Returns:
- A pair containing:
-
- - Output: A tensor with cell's output.
- - New state: A tensor or tuple of tensors with new wrapped cell's state.
- """
- return self._call(inputs, state, call_fn=self._cell.call, **kwargs)
-
-
-@tf_export("nn.rnn_cell.ResidualWrapper")
-class ResidualWrapper(RNNCell):
+@tf_export(v1=["nn.rnn_cell.ResidualWrapper"])
+class ResidualWrapper(_RNNCellWrapperV1):
"""RNNCell wrapper that ensures cell inputs are added to the outputs."""
def __init__(self, cell, residual_fn=None):
@@ -1422,10 +1529,7 @@
Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
and outputs.
"""
- super(ResidualWrapper, self).__init__()
- self._cell = cell
- if isinstance(cell, checkpointable.Checkpointable):
- self._track_checkpointable(self._cell, name="cell")
+ super(ResidualWrapper, self).__init__(cell)
self._residual_fn = residual_fn
@property
@@ -1440,13 +1544,15 @@
with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
return self._cell.zero_state(batch_size, dtype)
- def __call__(self, inputs, state, scope=None):
+ def _call_wrapped_cell(self, inputs, state, cell_call_fn, **kwargs):
"""Run the cell and then apply the residual_fn on its inputs to its outputs.
Args:
inputs: cell inputs.
state: cell state.
- scope: optional cell scope.
+ cell_call_fn: Wrapped cell's method to use for step computation (cell's
+ `__call__` or 'call' method).
+ **kwargs: Additional arguments passed to the wrapped cell's `call`.
Returns:
Tuple of cell outputs and new state.
@@ -1455,7 +1561,7 @@
TypeError: If cell inputs and outputs have different structure (type).
ValueError: If cell inputs and outputs have different structure (value).
"""
- outputs, new_state = self._cell(inputs, state, scope=scope)
+ outputs, new_state = cell_call_fn(inputs, state, **kwargs)
# Ensure shapes match
def assert_shape_match(inp, out):
inp.get_shape().assert_is_compatible_with(out.get_shape())
@@ -1467,6 +1573,29 @@
return (res_outputs, new_state)
+@tf_export("rnn.ResidualWrapper", v1=[])
+class ResidualWrapperV2(_RNNCellWrapperV2, ResidualWrapper):
+ """RNNCell wrapper that ensures cell inputs are added to the outputs."""
+
+ def __init__(self, cell, residual_fn=None):
+ """Constructs a `ResidualWrapperV2` for `cell`.
+
+ Runs initialization in Keras style scope to use Keras-style variable
+ management.
+
+ Args:
+ cell: An instance of `LayerRNNCell`.
+ residual_fn: (Optional) The function to map raw cell inputs and raw cell
+ outputs to the actual cell outputs of the residual network.
+ Defaults to calling nest.map_structure on (lambda i, o: i + o), inputs
+ and outputs.
+ """
+
+ with base_layer.keras_style_scope():
+ super(ResidualWrapperV2, self).__init__(
+ cell=cell, residual_fn=residual_fn)
+
+
@tf_export("nn.rnn_cell.DeviceWrapper")
class DeviceWrapper(RNNCell):
"""Operator that ensures an RNNCell runs on a particular device."""
diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py
index 41a4814..1a11c33 100644
--- a/tensorflow/python/ops/tensor_array_ops.py
+++ b/tensorflow/python/ops/tensor_array_ops.py
@@ -542,14 +542,20 @@
def read(self, index, name=None):
"""See TensorArray."""
- value = list_ops.tensor_list_get_item(
- input_handle=self._flow,
- index=index,
- element_dtype=self._dtype,
- name=name)
- if self._element_shape:
- value.set_shape(self._element_shape[0].dims)
- return value
+ with ops.name_scope(name, "TensorArrayV2Read", [self._flow, index]):
+ if self._element_shape:
+ element_shape = self._element_shape[0]
+ else:
+ element_shape = tensor_shape.TensorShape(None)
+ value = list_ops.tensor_list_get_item(
+ input_handle=self._flow,
+ index=index,
+ element_dtype=self._dtype,
+ element_shape=element_shape,
+ name=name)
+ if self._element_shape:
+ value.set_shape(self._element_shape[0].dims)
+ return value
@tf_should_use.should_use_result
def write(self, index, value, name=None):
@@ -819,7 +825,7 @@
if self._infer_shape:
if self._element_shape is None:
self._element_shape = value.shape
- elif self._element_shape != value.shape:
+ elif not self._element_shape.is_compatible_with(value.shape):
raise ValueError("Incompatible shape for value (%s), expected (%s)" %
(value.shape.as_list(), self._element_shape.as_list()))
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 0e427d3..68f1cba 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -39,10 +39,11 @@
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gen_resource_variable_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
+from tensorflow.python.ops import while_v2_indexed_slices_rewriter
from tensorflow.python.util import nest
# pylint: disable=protected-access
@@ -122,40 +123,34 @@
cond_graph = func_graph_module.func_graph_from_py_func(
cond_name,
wrapped_cond,
- loop_vars, {},
+ [], # We provide signature instead of args.
+ {},
signature=_build_signature(loop_vars, shape_invariants),
func_graph=util.WhileCondFuncGraph(
cond_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
add_control_dependencies=add_control_dependencies)
- # Add external_captures of cond to the list of loop vars.
- # Note that external tensors will be treated as loop invariants, i.e.,
- # the value of that tensor in each iteration is the same as it was at the
- # beginning of the loop execution.
- loop_vars = loop_vars + cond_graph.external_captures
- shape_invariants = shape_invariants + type(shape_invariants)(
- [t.shape for t in cond_graph.external_captures])
-
def wrapped_body(loop_counter, *args):
"""Loop body augmented with counter update.
Args:
loop_counter: Loop counter which needs to be incremented in the body.
*args: List of args
- args[:len_orig_loop_vars] - Args for the original loop body.
- args[len_orig_loop_vars:] - External captures of cond. These get
- passed through as is.
Returns:
A list of tensors the same length as args.
"""
+ # Capture the tensors already captured in cond_graph so that they appear
+ # in the same order in body_graph.external_captures.
+ for t in cond_graph.external_captures:
+ ops.get_default_graph().capture(t)
+
# Convert the flow variables in `args` to TensorArrays. `args` should
# already have the same structure as `orig_loop_vars` but currently there
# is no nest.zip so we call `_pack_sequence_as` which flattens both
# `orig_loop_vars` and `args`, converts flows in `args` to TensorArrays
# and packs it into the structure of `orig_loop_vars`.
- outputs = body(
- *_pack_sequence_as(orig_loop_vars, args[:len_orig_loop_vars]))
+ outputs = body(*_pack_sequence_as(orig_loop_vars, args))
if not nest.is_sequence(outputs):
outputs = [outputs]
# Compare the structure of input and output of body converting the
@@ -164,17 +159,15 @@
outputs = _tensor_array_to_flow(outputs)
- # Return the external_captures of cond_graph as is, i.e., treat them as
- # loop invariants.
# TODO(srbs): Update lowering code to create _Enter nodes with
# is_constant=True for inputs that are directly passed to outputs.
- return [loop_counter + 1] + list(outputs) + list(
- args[len_orig_loop_vars:])
+ return [loop_counter + 1] + list(outputs)
body_graph = func_graph_module.func_graph_from_py_func(
body_name,
wrapped_body,
- loop_vars, {},
+ [], # We provide signature instead of args.
+ {},
signature=_build_signature(loop_vars, shape_invariants),
func_graph=util.WhileBodyFuncGraph(
body_name, collections=ops.get_default_graph()._collections), # pylint: disable=protected-access
@@ -188,17 +181,15 @@
# is_constant=True for inputs that are directly passed to outputs.
body_graph.outputs.extend(body_graph.internal_captures)
- # Capture `external_captures` of `body_graph` in `cond_graph` so that it
- # expects to receive those as arguments.
- # TODO(b/118457764): Dedup tensors that are captured in both the cond and
- # body. This logic already exists in cond_v2.
+ # Capture the extra `external_captures` of `body_graph` in `cond_graph` so
+ # that it expects to receive those as arguments.
with cond_graph.as_default():
- for external_capture in body_graph.external_captures:
- assert external_capture not in cond_graph.captures, (
- "Looks like both cond and body are capturing the same tensor %s. "
- "This is not supported yet. For now consider passing,"
- " this as a loop variable." % str(external_capture))
- cond_graph.capture(external_capture)
+ num_cond_captures = len(cond_graph.external_captures)
+ assert (cond_graph.external_captures ==
+ body_graph.external_captures[:num_cond_captures])
+ for body_capture in body_graph.external_captures[num_cond_captures:]:
+ assert body_capture not in cond_graph.captures
+ cond_graph.capture(body_capture)
# Make sure that the shapes of the loop outputs are compatible with the
# shape invariants, or the shapes of the loop vars if the invariants are not
@@ -303,6 +294,10 @@
while_op)
loop_vars = args + captured_inputs
+ # This modifies body_grad_graph.
+ loop_vars = while_v2_indexed_slices_rewriter.rewrite_grad_indexed_slices(
+ grads, body_grad_graph, loop_vars, while_op.inputs)
+
def grad_cond(counter, max_iters, *unused_args):
return counter < max_iters
@@ -320,26 +315,15 @@
output_shapes=[t.shape for t in body_grad_graph.outputs],
parallel_iterations=parallel_iterations,
name="%s_grad" % while_op.name)
+ grad_op = outputs[0].op
_copy_handle_data(body_grad_graph.outputs, outputs)
- util.maybe_set_lowering_attr(outputs[0].op)
- _maybe_set_maximum_iterations_attr(outputs[0].op, maximum_iterations)
+ util.maybe_set_lowering_attr(grad_op)
+ _maybe_set_maximum_iterations_attr(grad_op, maximum_iterations)
# See comment in while_loop.
outputs = [array_ops.identity(t) for t in outputs]
-
- # Set None as the output gradient for tensors with None input gradient.
- # outputs[0] is the loop counter.
- # outputs[1] is the total number of loop iterations.
- index = 2
- none_padded_outputs = []
- for g in grads:
- if g is None:
- none_padded_outputs.append(None)
- else:
- none_padded_outputs.append(outputs[index])
- index += 1
- return none_padded_outputs
+ return _get_structured_grad_output(outputs, grads, body_grad_graph)
def _preprocess_grad(grad, body_graph_output, while_op_output):
@@ -375,6 +359,8 @@
return grad
+# TODO(skyewm): make this return constants if op_output's shape is fully
+# defined (this can be done by checking the "shape" attr of resource vars).
def _zeros_like(op_output):
"""Like array_ops.zeros_like() but also accepts resource var handles."""
if op_output.dtype == dtypes.resource:
@@ -385,7 +371,7 @@
def _is_trainable(tensor):
"""Returns whether the given tensor is trainable."""
- if not gradients_impl.IsTrainable(tensor):
+ if not gradients_util.IsTrainable(tensor):
return False
# Special case: untrainable accumulator output. The gradients algorithm
@@ -396,7 +382,7 @@
if tensor.op.type == "TensorListPopBack" and tensor.value_index == 0:
assert tensor.dtype == dtypes.variant
element_type = tensor.op.get_attr("element_dtype")
- return gradients_impl.IsTrainable(element_type)
+ return gradients_util.IsTrainable(element_type)
return True
@@ -510,14 +496,15 @@
# Add the popped accumulators to the list of outputs.
for internal_capture in grad_func_graph.internal_captures:
if internal_capture in grad_func_graph.popped_tensor_lists:
- grad_func_graph.outputs.append(
- grad_func_graph.popped_tensor_lists[internal_capture])
+ new_output = grad_func_graph.popped_tensor_lists[internal_capture]
elif internal_capture.dtype == dtypes.resource:
- grad_func_graph.outputs.append(internal_capture)
+ new_output = internal_capture
else:
raise ValueError("Tensor %s is in list of internal_captures but is"
" neither a resource nor is in popped_tensor_lists." %
str(internal_capture))
+ grad_func_graph.outputs.append(new_output)
+ grad_func_graph.structured_outputs.append(new_output)
return grad_func_graph, args
@@ -547,7 +534,7 @@
# func_graph. The captured func_graph tensors are resolved to external tensors
# after the forward While op has been rewritten in _resolve_grad_captures.
# TODO(srbs): Mark GradientsHelper as public?
- grad_outs = gradients_impl._GradientsHelper(
+ grad_outs = gradients_util._GradientsHelper(
ys, xs, grad_ys=grad_ys, src_graph=func_graph,
unconnected_gradients="zero")
@@ -600,6 +587,45 @@
return new_capture_inputs
+def _get_structured_grad_output(outputs, grads, body_grad_graph):
+ """Returns the values that should be returned from the while grad function.
+
+ Args:
+ outputs: the raw Tensor outputs of the grad While op.
+ grads: the input gradients to the gradient function.
+ body_grad_graph: _WhileBodyGradFuncGraph.
+
+ Returns:
+ A list of gradient values. May include Nones.
+ """
+ result = []
+ # outputs[0] is the loop counter.
+ # outputs[1] is the total number of loop iterations.
+ outputs_idx = 2
+ structured_outputs_idx = 2
+ for g in grads:
+ # Set None as the output gradient for tensors with None input gradient.
+ if g is None:
+ result.append(None)
+ continue
+ output = body_grad_graph.structured_outputs[structured_outputs_idx]
+ structured_outputs_idx += 1
+ if isinstance(output, ops.IndexedSlices):
+ # TODO(skyewm): is there a more robust way to determine the order of
+ # flattened IndexedSlices components?
+ result.append(ops.IndexedSlices(
+ values=outputs[outputs_idx],
+ indices=outputs[outputs_idx + 1],
+ dense_shape=outputs[outputs_idx + 2]))
+ outputs_idx += 3
+ else:
+ assert isinstance(output, ops.Tensor)
+ result.append(outputs[outputs_idx])
+ outputs_idx += 1
+
+ return result
+
+
def _get_accumulator(tensor):
r"""Returns TensorList if any containing accumulated values of tensor.
@@ -741,9 +767,9 @@
"""
if (not whitelisted and tensor.graph is not self and
tensor.graph != self._forward_graph):
- raise ValueError("Attempting to capture tensor", str(tensor),
- " which is not in the forward graph but in ",
- _graph_name(tensor.graph), ".")
+ raise ValueError("Attempting to capture tensor %s which is not in the "
+ "forward graph but in %s." %
+ (str(tensor), _graph_name(tensor.graph)))
return super(_WhileBodyGradFuncGraph, self).capture(tensor, name)
def _capture_helper(self, tensor, name):
@@ -816,27 +842,97 @@
Tensor in this graph.
"""
assert tensor.dtype == dtypes.resource
- if tensor in self._forward_graph.inputs:
- index = self._forward_graph.inputs.index(tensor)
- elif tensor.op.type == "While":
- # Captured resources occur at the same index in the lists of inputs and
- # outputs of a while op. So we lookup the input of `tensor.op` at the
- # same index as the index of `tensor` in the `tensor.op.outputs`.
- index = self._forward_graph.inputs.index(
- tensor.op.inputs[tensor.value_index])
- else:
- raise ValueError(
- "Taking gradient of a while loop which creates "
- "a resource in its body is not supported: %s" % tensor)
- # This must be a loop invariant.
- assert self._forward_graph.inputs[index] == self._forward_graph.outputs[
- index], ("Resource tensors must be loop invariants %s." %
- self._forward_graph._while.inputs[index])
+
+ index = self._resource_input_index(
+ tensor.name,
+ [t.name for t in self._forward_graph.inputs],
+ {op.name: op.node_def for op in self._forward_graph.get_operations()},
+ self._forward_graph._functions)
+
+ input_placeholder = self._forward_graph.inputs[index]
tensor_in_outer_graph = self._forward_graph._while.inputs[index]
+
+ assert input_placeholder.dtype == dtypes.resource
+ assert tensor_in_outer_graph.dtype == dtypes.resource
+ # This must be a loop invariant.
+ assert input_placeholder == self._forward_graph.outputs[index], (
+ "Resource tensors must be loop invariants %s." %
+ tensor_in_outer_graph)
+
self._indirect_captures[tensor] = self.capture(
tensor_in_outer_graph, whitelisted=True)
return self._indirect_captures[tensor]
+ def _resource_input_index(self, tensor_name, input_names, node_defs,
+ functions):
+ """Returns the index of the input corresponding to `tensor_name`.
+
+ This method is used to find the corresponding index of an arbitrary resource
+ tensor in a function (the function could be a loop body). We assume that
+ resource handles are never created in functions, so that every resource
+ tensor can be traced back to a function input.
+
+ The awkward signature of this method is to make it work with both FuncGraphs
+ and FunctionDefs. This is so we can recurse on function call ops without
+ building the corresponding FuncGraph (note that even if a FuncGraph for a
+ FunctionDef already exists, the input/output/node names may have been
+ changed when the FuncGraph was serialized to the FunctionDef, which makes it
+ unusable with this algorithm).
+
+ Args:
+ tensor_name: the name of the resource tensor to be resolved to an input.
+ input_names: a list of the names of all inputs to the function.
+ node_defs: a dict mapping op name -> NodeDef for every op in the function.
+ functions: a dict mapping function name -> _EagerDefinedFunction.
+
+ Returns:
+ The index into input_names corresponding to `tensor_name`.
+ """
+ while tensor_name not in input_names:
+ # FunctionDefs and graphs use different tensor naming conventions.
+ parts = tensor_name.split(":")
+ if len(parts) == 3:
+ op_name, _, output_idx = parts
+ elif len(parts) == 2:
+ op_name, output_idx = parts
+ else:
+ assert len(parts) == 1
+ op_name = parts[0]
+ output_idx = 0
+ output_idx = int(output_idx)
+ node_def = node_defs[op_name]
+
+ if node_def.op == "While":
+ # Captured resources occur at the same index in the lists of inputs and
+ # outputs of a while op. So we lookup the input of `tensor.op` at the
+ # same index as the index of `tensor` in the `tensor.op.outputs`.
+ tensor_name = node_def.input[output_idx]
+ elif node_def.op in ("PartitionedCall", "StatefulPartitionedCall"):
+ # Functions output any captured resource tensors used by their
+ # gradients. `tensor_name` is one of these outputs from a nested
+ # function call, so recursively find the corresponding input in the
+ # nested FunctionDef.
+ func_name = node_def.attr["f"].func.name
+ fdef = functions[func_name].definition
+ output_arg_name = fdef.signature.output_arg[output_idx].name
+ output_tensor_name = fdef.ret[output_arg_name]
+ input_index = self._resource_input_index(
+ output_tensor_name,
+ [arg.name for arg in fdef.signature.input_arg],
+ {ndef.name: ndef for ndef in fdef.node_def},
+ functions)
+ tensor_name = node_def.input[input_index]
+ else:
+ # We assume there are no other ops types that will "forward" resource
+ # handles like this, so all other handles must have been created by the
+ # op. (Note that cond_v2 wraps resource handle outputs in optionals,
+ # which we'll end up accumulating).
+ raise ValueError(
+ "Taking gradient of a while loop which creates "
+ "a resource in its body is not supported: %s" % op_name)
+
+ return input_names.index(tensor_name)
+
def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
for (t, shape, input_t) in zip(output_tensors, shape_invariants,
@@ -857,7 +953,7 @@
assert len(cond_graph.outputs) == 1, (
"cond_graph has %d outputs; Expected: 1" % len(cond_graph.outputs))
assert len(body_graph.inputs) == num_flattened_loop_vars, (
- "body_graph takes %d inputs; Expected: %d" % (len(cond_graph.inputs),
+ "body_graph takes %d inputs; Expected: %d" % (len(body_graph.inputs),
num_flattened_loop_vars))
assert len(body_graph.outputs) == num_flattened_loop_vars, (
"body_graph has %d outputs; Expected: %d" % (len(body_graph.outputs),
diff --git a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py
new file mode 100644
index 0000000..30e9709
--- /dev/null
+++ b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py
@@ -0,0 +1,279 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# =============================================================================
+"""Methods for rewriting while_v2 grad functions with IndexedSlices output."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import func_graph
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.util import nest
+
+
+def rewrite_grad_indexed_slices(grads, body_grad_graph, loop_vars,
+ forward_inputs):
+ """Handles special case of IndexedSlices returned from while gradient.
+
+ Some gradient functions return IndexedSlices instead of a Tensor (e.g. the
+ gradient of Gather ops). When this happens in the gradient of a while body,
+ the resulting gradient body function will have mismatched inputs and outputs,
+ since the input is a single Tensor, but the IndexedSlices gets unnested into
+ three output Tensors.
+
+ This function fixes this by rewriting the gradient body to have three inputs
+ to match the three outputs, i.e., it effectively converts the input Tensor
+ into an input IndexedSlices. It also returns new `loop_vars` to reflect the
+ new inputs.
+
+ Args:
+ grads: the input gradient Tensors to the while gradient computation.
+ body_grad_graph: _WhileBodyGradFuncGraph.
+ loop_vars: list of Tensors. The inputs to body_grad_graph.
+ forward_inputs: list of Tensors. The (flat) inputs to the forward-pass
+ While op.
+
+ Returns:
+ The new loop_vars to pass to body_grad_graph.
+ """
+ # Match up body_grad_graph.structured_outputs with the corresponding
+ # forward_inputs.
+ #
+ # Note that we don't expect a gradient computation to have structured output
+ # (e.g. no nested lists), so no need to flatten
+ # body_grad_graph.structured_outputs. However, structured_outputs may still
+ # contain composite tensors such as IndexedSlices, unlike
+ # body_grad_graph.outputs, which contains flattened composite tensors.
+ inputs_with_grads = [t for g, t in zip(grads, forward_inputs)
+ if g is not None]
+ # Skip loop counter and total number of loop iterations.
+ structured_outputs = body_grad_graph.structured_outputs[2:]
+
+ for forward_input, output in zip(inputs_with_grads, structured_outputs):
+ if not isinstance(output, ops.IndexedSlices): continue
+
+ if forward_input.dtype == dtypes.resource:
+ # TODO(skyewm): In theory we should use this for all captured inputs, not
+ # just resource handles (which can only be captured). We can do this by
+ # checking that forward_input is passed straight through to its output.
+ loop_vars = _rewrite_input_as_indexed_slices(body_grad_graph, output,
+ forward_input, loop_vars)
+ else:
+ _rewrite_output_as_tensor(body_grad_graph, output)
+
+ return loop_vars
+
+
+def _rewrite_output_as_tensor(body_grad_graph, grad_output_slices):
+ """Rewrites grad_output_slices to be a Tensor output.
+
+ Args:
+ body_grad_graph: _WhileBodyGradFuncGraph.
+ grad_output_slices: IndexedSlices output of body_grad_graph.
+ """
+ with body_grad_graph.as_default():
+ new_output = ops.convert_to_tensor_v2(grad_output_slices)
+
+ idx = body_grad_graph.structured_outputs.index(grad_output_slices)
+ body_grad_graph.structured_outputs[idx] = new_output
+ body_grad_graph.outputs = func_graph.flatten(
+ body_grad_graph.structured_outputs)
+
+
+def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices,
+ forward_input, loop_vars):
+ """Rewrites grad_output_slices's corresponding input to be an IndexedSlices.
+
+ This rewrite requires that forward_input was captured in the forward loop,
+ i.e. is not a user-specified loop variable. This is important because the
+ rewrite assumes that forward_input is passed through to its corresponding
+ output unchanged. This assumption is used in _rewrite_input_as_indexed_slices,
+ which depends on the exact gradient structure produced by the input's fanout.
+
+ This can yield a more efficient computation than using
+ _rewrite_output_as_tensor, since it preserves the IndexedSlices structure
+ instead of converting the IndexedSlices to a dense Tensor.
+
+ Args:
+ body_grad_graph: _WhileBodyGradFuncGraph.
+ grad_output_slices: IndexedSlices output of body_grad_graph.
+ forward_input: the corresonding Tensor input to the forward loop.
+ loop_vars: list of Tensors. The inputs to body_grad_graph.
+
+ Returns:
+ The new loop_vars to pass to body_grad_graph.
+ """
+ # Create initial IndexedSlices that will be the input to the grad While
+ # op. This will start as zeros, and accumulate the IndexedSlices grad output.
+ # Note that because forward_input is captured and not a loop var, its incoming
+ # gradient should always be zero.
+ init_slices = _create_grad_indexed_slices_init(grad_output_slices,
+ forward_input)
+
+ # Create a new version of grad_output_slices's gradient computation that uses
+ # the new IndexedSlices input instead of the original Tensor input. We'll
+ # return the new computation and leave the old computation as dead code.
+ # TODO(skyewm): considering pruning body_grad_graph to remove the old
+ # computation.
+ with body_grad_graph.as_default():
+ input_slices = ops.IndexedSlices(
+ values=body_grad_graph.capture(init_slices.values, whitelisted=True),
+ indices=body_grad_graph.capture(init_slices.indices, whitelisted=True),
+ dense_shape=body_grad_graph.capture(init_slices.dense_shape,
+ whitelisted=True))
+
+ # Remove the captured tensors from the function inputs. We'll add them back
+ # at the correct index in _update_indexed_slices_param.
+ for t in _flatten(init_slices):
+ captured_t = body_grad_graph.captures.pop(t)
+ body_grad_graph.inputs.remove(captured_t)
+
+ new_output_slices = _rewrite_grad_indexed_slices_output(grad_output_slices,
+ input_slices)
+
+ # Update body_grad_graph's inputs and outputs to reflect the new
+ # IndexedSlices computation.
+ return _update_indexed_slices_param(
+ body_grad_graph, loop_vars, init_slices, input_slices, new_output_slices,
+ grad_output_slices)
+
+
+def _create_grad_indexed_slices_init(grad_output_slices, forward_input):
+ """Creates an IndexedSlices to pass as input to the while grad function.
+
+ Args:
+ grad_output_slices: IndexedSlices. The corresponding while grad function
+ output.
+ forward_input: Tensor. The corresonding input to the forward while op.
+
+ Returns:
+ Zeros IndexedSlices, created in current Graph.
+ """
+ assert isinstance(grad_output_slices, ops.IndexedSlices)
+ assert isinstance(forward_input, ops.Tensor)
+ values_out = grad_output_slices.values
+ indices_out = grad_output_slices.indices
+
+ # Create the initial values tensor.
+ if values_out.shape.is_fully_defined():
+ values_shape = tensor_shape.TensorShape([0] +
+ values_out.shape.as_list()[1:])
+ values = array_ops.zeros(values_shape, dtype=values_out.dtype,
+ name="values_init")
+ else:
+ if forward_input.dtype == dtypes.resource:
+ forward_shape = gen_resource_variable_ops.variable_shape(forward_input)
+ else:
+ forward_shape = array_ops.shape(forward_input)
+ values_shape = array_ops.concat([[0], forward_shape[1:]], 0)
+ values = array_ops.zeros(values_shape, dtype=values_out.dtype,
+ name="values_init")
+
+ # Create the initial indices tensor.
+ indices = constant_op.constant([], indices_out.dtype, name="indices_init")
+
+ # Create the initial dense_shape tensor. We assume is the same shape as
+ # forward_input, since captured tensors don't change shape across loop
+ # iterations.
+ if forward_input.dtype == dtypes.resource:
+ shape = gen_resource_variable_ops.variable_shape(forward_input,
+ name="shape_init")
+ else:
+ shape = array_ops.shape(forward_input, name="shape_init")
+
+ return ops.IndexedSlices(values=values, indices=indices, dense_shape=shape)
+
+
+def _rewrite_grad_indexed_slices_output(old_output_slices, new_input_slices):
+ """Creates a new verson of old_output_slices with new_input_slices as input.
+
+ This method assumes that old_output_slices.{values,indices} are produced by
+ concatenating the incoming gradient Tensor input with the IndexedSlices
+ produced by the gradient computation of the while body. See
+ gradients_impl._AggregateIndexedSlicesGradients for where these concats are
+ constructed. We build new concats that use new_input_slices instead of the
+ original Tensor input.
+
+ Args:
+ old_output_slices: original IndexedSlices output of while gradient.
+ new_input_slices: new IndexedSlices to use as input to while gradient.
+
+ Returns:
+ A new IndexedSlices to replace old_output_slices.
+ """
+
+ def rewrite(old_output, new_input):
+ assert old_output.type == "Identity"
+ concat_op = old_output.inputs[0].op
+ assert concat_op.type == "ConcatV2"
+ # Don't include axis arg
+ old_concat_args = concat_op.inputs[:-1]
+ # We assume that the original gradient input was the first argument to the
+ # concat op.
+ # TODO(skyewm): do this in a more robust way.
+ return array_ops.concat([new_input] + old_concat_args[1:], 0)
+
+ values = rewrite(old_output_slices.values.op, new_input_slices.values)
+ indices = rewrite(old_output_slices.indices.op, new_input_slices.indices)
+ return ops.IndexedSlices(values=values, indices=indices,
+ dense_shape=new_input_slices.dense_shape)
+
+
+def _update_indexed_slices_param(graph, loop_vars, init_slices, input_slices,
+ output_slices, old_output_slices):
+ """Updates graph with new IndexedSlices input/output.
+
+ Updates graph's metadata to output the gradient computation defined by
+ init_slices, input_slices, and output_slices, instead of outputting
+ old_output_slices. Also returns a new version of loop_vars with init_slices
+ replacing the old input.
+
+ Args:
+ graph: _WhileBodyGradFuncGraph.
+ loop_vars: the inputs to graph.
+ init_slices: the new IndexedSlices to use as input to graph.
+ input_slices: the new IndexedSlices in graph that should be fed by
+ init_slices.
+ output_slices: the new IndexedSlices in graph that should be the
+ corresonding output to input_slices.
+ old_output_slices: the IndexedSlices in graph that are currently
+ being output.
+
+ Returns:
+ New loop_vars to pass to graph.
+ """
+ structured_idx = graph.structured_outputs.index(old_output_slices)
+ # We assume that the component tensors of old_output_slices appear
+ # sequentially in graph.outputs. We use the first of these tensors
+ # as the reference index.
+ flat_idx = graph.outputs.index(func_graph.flatten(old_output_slices)[0])
+
+ graph.structured_outputs[structured_idx] = output_slices
+ graph.outputs = func_graph.flatten(
+ graph.structured_outputs)
+
+ graph.inputs = (graph.inputs[:flat_idx] + _flatten(input_slices) +
+ graph.inputs[flat_idx + 1:])
+
+ return loop_vars[:flat_idx] + _flatten(init_slices) + loop_vars[flat_idx + 1:]
+
+
+def _flatten(arg):
+ return nest.flatten(arg, expand_composites=True)
diff --git a/tensorflow/python/profiler/internal/run_metadata_test.py b/tensorflow/python/profiler/internal/run_metadata_test.py
index f96d721..9e92a8f 100644
--- a/tensorflow/python/profiler/internal/run_metadata_test.py
+++ b/tensorflow/python/profiler/internal/run_metadata_test.py
@@ -50,7 +50,7 @@
dev = dev[dev.find('cpu:'):]
elif dev.find('gpu:') > 0:
dev = dev[dev.find('gpu:'):]
- else:
+ elif '/host:cpu' not in dev:
assert False, 'Unrecognized device name: %s' % dev
for node_stat in dev_stat.node_stats:
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index bae20ac..dd74f12 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -69,6 +69,7 @@
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
%rename("%s") TFE_Py_TensorShapeOnDevice;
+%rename("%s") TFE_Py_EnableInteractivePythonLogging;
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 5d08a40..76670bf 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -262,6 +262,16 @@
)
py_library(
+ name = "signature_serialization",
+ srcs = [
+ "signature_serialization.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ],
+)
+
+py_library(
name = "save",
srcs = [
"save.py",
@@ -271,10 +281,12 @@
":builder",
":constants",
":function_serialization",
+ ":nested_structure_coder",
":revived_types",
":saved_object_graph_py",
":signature_constants",
":signature_def_utils",
+ ":signature_serialization",
":tag_constants",
":utils",
"//tensorflow/core:protos_all_py",
@@ -286,14 +298,16 @@
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:resource_variable_ops",
- "//tensorflow/python:tensor_spec",
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/eager:function",
"//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:graph_view",
+ "//tensorflow/python/training/checkpointable:object_identity",
"//tensorflow/python/training/checkpointable:tracking",
"//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python/training/saving:functional_saver",
],
)
@@ -321,13 +335,22 @@
":constants",
":function_deserialization",
":loader",
+ ":nested_structure_coder",
":revived_types",
":saved_object_graph_py",
":utils",
- "//tensorflow/python:function",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:init_ops",
"//tensorflow/python:lib",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:tensor_util",
"//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/training/checkpointable:base",
+ "//tensorflow/python/training/checkpointable:graph_view",
"//tensorflow/python/training/checkpointable:tracking",
+ "//tensorflow/python/training/checkpointable:util",
],
)
diff --git a/tensorflow/python/saved_model/function_deserialization.py b/tensorflow/python/saved_model/function_deserialization.py
index e82e642..992b62a 100644
--- a/tensorflow/python/saved_model/function_deserialization.py
+++ b/tensorflow/python/saved_model/function_deserialization.py
@@ -105,18 +105,15 @@
super(RestoredFunction, self).__init__(
python_function, name, autograph=False)
self._concrete_functions = concrete_functions
- # TODO(vbardiovsky): This does not propagate to stateful and stateless
- # functions of the RestoredFunction, which will have seen only defunned
- # restored_function_body(*args, **kwargs). Therefore get_concrete_function()
- # called on RestoredFunction will not work properly.
+ # This does not propagate to stateful and stateless functions of the
+ # RestoredFunction, which will have seen only defunned
+ # restored_function_body(*args, **kwargs). That's why we have to
+ # canonicalize inputs inside restored_function_body.
self._function_spec = function_spec
def _list_all_concrete_functions_for_serialization(self):
return self._concrete_functions
- def get_concrete_function(self, *args, **kwargs):
- raise NotImplementedError()
-
def recreate_function(saved_function, concrete_functions):
"""Creates a `Function` from a `SavedFunction`.
diff --git a/tensorflow/python/saved_model/load.py b/tensorflow/python/saved_model/load.py
index 4b9ba02..e8474f2 100644
--- a/tensorflow/python/saved_model/load.py
+++ b/tensorflow/python/saved_model/load.py
@@ -22,6 +22,7 @@
import os
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import init_ops
@@ -34,6 +35,8 @@
from tensorflow.python.saved_model import revived_types
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import utils_impl as saved_model_utils
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
from tensorflow.python.util import compat
@@ -127,9 +130,34 @@
setattr(type(obj), "__call__", _call_attribute)
def _restore_checkpoint(self):
+ """Load state from checkpoint into the deserialized objects."""
variables_path = saved_model_utils.get_variables_path(self._export_dir)
- saver = util.CheckpointableSaver(self.get(0))
- saver.restore(variables_path).assert_consumed()
+ # TODO(andresp): Clean use of private methods of CheckpointableSaver.
+ # pylint: disable=protected-access
+ saver = util.CheckpointableSaver(graph_view.ObjectGraphView(self.get(0)))
+ saver._file_prefix_placeholder = constant_op.constant(variables_path)
+ load_status = saver.restore(variables_path)
+ load_status.assert_existing_objects_matched()
+ checkpoint = load_status._checkpoint
+
+ # When running in eager mode, the `restore` call above has already run and
+ # restored the state of checkpointables, call `position.restore_ops()` will
+ # return an empty list as there is nothing left to do. In graph mode, that
+ # will return the list of ops that must run to restore the object on that
+ # position. We have to wire them in the initializers of the objects so that
+ # they get initialized properly when using common practices (e.g. the ones
+ # used by ManagedSession) without further user action.
+ for object_id, obj in dict(checkpoint.object_by_proto_id).items():
+ position = base.CheckpointPosition(checkpoint=checkpoint,
+ proto_id=object_id)
+ restore_ops = position.restore_ops()
+ if restore_ops:
+ if resource_variable_ops.is_resource_variable(obj):
+ obj._initializer_op = restore_ops
+ else:
+ raise NotImplementedError(
+ ("Missing functionality to restore state of object "
+ "%r from the checkpoint." % obj))
def get(self, node_id):
return self._nodes[node_id]
@@ -210,10 +238,11 @@
compat.as_bytes("object_graph.pb"))
if file_io.file_exists(object_graph_filename):
object_graph_proto = _load_saved_object_graph_proto(object_graph_filename)
- loader = _Loader(object_graph_proto,
- saved_model_proto,
- export_dir)
- root = loader.get(0)
+ with ops.init_scope():
+ loader = _Loader(object_graph_proto,
+ saved_model_proto,
+ export_dir)
+ root = loader.get(0)
else:
raise NotImplementedError(
"Currently only SavedModels exported with `tf.saved_model.save` may be "
diff --git a/tensorflow/python/saved_model/load_test.py b/tensorflow/python/saved_model/load_test.py
index 299b6ae..1edb089 100644
--- a/tensorflow/python/saved_model/load_test.py
+++ b/tensorflow/python/saved_model/load_test.py
@@ -29,6 +29,7 @@
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@@ -37,6 +38,7 @@
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import save
+from tensorflow.python.training import monitored_session
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
@@ -310,7 +312,7 @@
imported = self.cycle(root, cycles)
with self.assertRaisesRegexp(AssertionError,
- "Could not find matching function to call.*"):
+ "Could not find matching function to call"):
imported.f(input2)
self.assertEqual(31, imported.f(input1).numpy())
@@ -536,6 +538,50 @@
x = constant_op.constant(1.0)
self.assertAllEqual(imported(x).numpy(), 3.0)
+ def test_load_in_graph_mode(self, cycles):
+ root = tracking.AutoCheckpointable()
+ root.v1 = variables.Variable(1.)
+ root.v2 = variables.Variable(2.)
+ root.f = def_function.function(
+ lambda x: root.v2 * x,
+ input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
+
+ if cycles > 1:
+ root = self.cycle(root, cycles - 1)
+ path = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ save.save(root, path)
+
+ with ops.Graph().as_default():
+ imported = load.load(path)
+ var_v1 = imported.v1
+ output = imported.f(constant_op.constant(2.))
+ with monitored_session.MonitoredSession() as sess:
+ self.assertEqual(1.0, sess.run(var_v1))
+ self.assertEqual(4.0, sess.run(output))
+
+ def test_load_in_func_graph(self, cycles):
+ root = tracking.AutoCheckpointable()
+ root.v1 = variables.Variable(1.)
+ root.v2 = variables.Variable(2.)
+ root.f = def_function.function(
+ lambda x: root.v2 * x,
+ input_signature=[tensor_spec.TensorSpec(None, dtypes.float32)])
+
+ if cycles > 1:
+ root = self.cycle(root, cycles - 1)
+ path = tempfile.mkdtemp(prefix=self.get_temp_dir())
+ save.save(root, path)
+
+ closure = tracking.AutoCheckpointable()
+ @def_function.function
+ def func(x):
+ if not hasattr(closure, "model"):
+ closure.model = load.load(path)
+ return closure.model.f(x)
+
+ inputs = constant_op.constant(2.)
+ self.assertEqual(4.0, func(inputs).numpy())
+
def test_soft_matching(self, cycles):
@def_function.function(
@@ -566,6 +612,36 @@
self.assertAllEqual([2, 4, 6],
imported.f(constant_op.constant([1, 2, 3])).numpy())
+ def test_get_concrete_function(self, cycles):
+
+ @def_function.function
+ def func(x, training=False):
+ if training:
+ return 2 * x
+ else:
+ return 3 * x
+
+ func.get_concrete_function(
+ tensor_spec.TensorSpec([None], dtypes.int32), True)
+ func.get_concrete_function(tensor_spec.TensorSpec([None], dtypes.float32))
+
+ root = tracking.AutoCheckpointable()
+ root.f = func
+
+ imported = self.cycle(root, cycles)
+
+ concrete = imported.f.get_concrete_function(
+ training=True, x=tensor_spec.TensorSpec([None], dtypes.int32))
+
+ self.assertAllEqual([2, 4, 6, 8],
+ concrete(x=constant_op.constant([1, 2, 3, 4])).numpy())
+ with self.assertRaisesRegexp(AssertionError,
+ "Could not find matching function to call"):
+ imported.f.get_concrete_function(
+ tensor_spec.TensorSpec([None], dtypes.int32))
+ imported.f.get_concrete_function(
+ tensor_spec.TensorSpec([None], dtypes.int32), True)
+
def test_concrete_function(self, cycles):
@def_function.function(
@@ -802,5 +878,68 @@
self.assertEqual(
2, imported.table_user(constant_op.constant("gamma")).numpy())
+ def test_functions_accessed_once(self, cycles):
+
+ class Exported(tracking.AutoCheckpointable):
+
+ def __init__(self):
+ self._counter = 0
+
+ @property
+ def make_func(self):
+ @def_function.function
+ def f():
+ return constant_op.constant(self._counter)
+ f.get_concrete_function() # force a trace
+ self._counter += 1
+ return f
+
+ exported = Exported()
+ imported = self.cycle(exported, cycles)
+ self.assertEqual(0, imported.make_func().numpy())
+ self.assertEqual(1, exported.make_func().numpy())
+
+ def test_overwritten_signatures_error(self, cycles):
+ exported = tracking.AutoCheckpointable()
+ exported.f = def_function.function(lambda: constant_op.constant(1.))
+ imported = self.cycle(
+ exported, cycles,
+ signatures={"key": exported.f.get_concrete_function()})
+ self.assertEqual(1., imported.signatures["key"]()["output_0"].numpy())
+ imported.signatures = {"key1": imported.signatures["key"]}
+ with self.assertRaisesRegexp(ValueError, "signatures"):
+ save.save(imported, tempfile.mkdtemp(prefix=self.get_temp_dir()))
+
+ def test_signature_loading(self, cycles):
+
+ class Exported(tracking.AutoCheckpointable):
+
+ def __init__(self):
+ self.v = variables.Variable(3.)
+
+ @def_function.function
+ def do(self, x):
+ return self.v * x
+
+ exported = Exported()
+ imported = self.cycle(
+ exported,
+ signatures=exported.do.get_concrete_function(
+ tensor_spec.TensorSpec(None, dtypes.float32)))
+ for _ in range(cycles - 1):
+ imported = self.cycle(imported, signatures=imported.signatures)
+ self.assertEqual(["serving_default"], list(imported.signatures.keys()))
+ imported_function = imported.signatures["serving_default"]
+ two = constant_op.constant(2.)
+ self.assertEqual(6., imported_function(x=two)["output_0"].numpy())
+ imported.v.assign(4.)
+ self.assertEqual(8., imported_function(x=two)["output_0"].numpy())
+ with self.assertRaisesRegexp(TypeError, "positional"):
+ imported_function(two)
+ with self.assertRaises(TypeError):
+ # The signatures mapping is immutable
+ imported.signatures["random_key"] = 3
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/saved_model/model_utils/__init__.py b/tensorflow/python/saved_model/model_utils/__init__.py
index 84540ba..3f54c96 100644
--- a/tensorflow/python/saved_model/model_utils/__init__.py
+++ b/tensorflow/python/saved_model/model_utils/__init__.py
@@ -25,4 +25,5 @@
from tensorflow.python.saved_model.model_utils.export_utils import get_export_outputs
from tensorflow.python.saved_model.model_utils.export_utils import get_temp_export_dir
from tensorflow.python.saved_model.model_utils.export_utils import get_timestamped_export_dir
+from tensorflow.python.saved_model.model_utils.export_utils import SIGNATURE_KEY_MAP
# pylint: enable=wildcard-import
diff --git a/tensorflow/python/saved_model/model_utils/export_utils.py b/tensorflow/python/saved_model/model_utils/export_utils.py
index 4f89337..343a336 100644
--- a/tensorflow/python/saved_model/model_utils/export_utils.py
+++ b/tensorflow/python/saved_model/model_utils/export_utils.py
@@ -41,6 +41,18 @@
mode_keys.ModeKeys.TEST: [tag_constants.EVAL],
}
+# For every exported mode, a SignatureDef map should be created using the
+# functions `export_outputs_for_mode` and `build_all_signature_defs`. By
+# default, this map will contain a single Signature that defines the input
+# tensors and output predictions, losses, and/or metrics (depending on the mode)
+# The default keys used in the SignatureDef map are defined below.
+SIGNATURE_KEY_MAP = {
+ mode_keys.ModeKeys.PREDICT:
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ mode_keys.ModeKeys.TRAIN:
+ signature_constants.DEFAULT_TRAIN_SIGNATURE_DEF_KEY,
+ mode_keys.ModeKeys.TEST: signature_constants.DEFAULT_EVAL_SIGNATURE_DEF_KEY
+}
_SINGLE_FEATURE_DEFAULT_NAME = 'feature'
_SINGLE_RECEIVER_DEFAULT_NAME = 'input'
@@ -262,18 +274,21 @@
Raises:
ValueError: if an appropriate ExportOutput cannot be found for the mode.
"""
- # TODO(b/113185250): move all model export helper functions into an util file.
+ if mode not in SIGNATURE_KEY_MAP:
+ raise ValueError(
+ 'Export output type not found for mode: {}. Expected one of: {}.\n'
+ 'One likely error is that V1 Estimator Modekeys were somehow passed to '
+ 'this function. Please ensure that you are using the new ModeKeys.'
+ .format(mode, SIGNATURE_KEY_MAP.keys()))
+ signature_key = SIGNATURE_KEY_MAP[mode]
if mode == mode_keys.ModeKeys.PREDICT:
return get_export_outputs(serving_export_outputs, predictions)
elif mode == mode_keys.ModeKeys.TRAIN:
- return {mode: export_output_lib.TrainOutput(
- loss=loss, predictions=predictions, metrics=metrics)}
- elif mode == mode_keys.ModeKeys.TEST:
- return {mode: export_output_lib.EvalOutput(
+ return {signature_key: export_output_lib.TrainOutput(
loss=loss, predictions=predictions, metrics=metrics)}
else:
- raise ValueError(
- 'Export output type not found for mode: {}'.format(mode))
+ return {signature_key: export_output_lib.EvalOutput(
+ loss=loss, predictions=predictions, metrics=metrics)}
def get_export_outputs(export_outputs, predictions):
diff --git a/tensorflow/python/saved_model/revived_types.py b/tensorflow/python/saved_model/revived_types.py
index 8f82039..ae06320 100644
--- a/tensorflow/python/saved_model/revived_types.py
+++ b/tensorflow/python/saved_model/revived_types.py
@@ -26,8 +26,7 @@
"""Holds information about one version of a revived type."""
def __init__(self, object_factory, version, min_producer_version,
- min_consumer_version, bad_consumers=None, setter=setattr,
- getter=getattr, attribute_extractor=dir):
+ min_consumer_version, bad_consumers=None, setter=setattr):
"""Identify a revived type version.
Args:
@@ -60,16 +59,8 @@
addition to any version less than `min_consumer_version`).
setter: A callable with the same signature as `setattr` to use when adding
dependencies to generated objects.
- getter: A callable with the same signature as `getattr` to use when
- retrieving items from objects of this type. Used along with
- `attribute_extractor` to find functions, which are not Checkpointable
- objects and so not regular dependencies.
- attribute_extractor: A callable equivalent of the builtin `dir`, used for
- listing items in this container (if any).
"""
self.setter = setter
- self.getter = getter
- self.attribute_extractor = attribute_extractor
self.identifier = None # Set after registration
self._object_factory = object_factory
self.version = version
@@ -146,15 +137,6 @@
_TYPE_IDENTIFIERS.append(identifier)
-def get_attribute_extractors(obj):
- """Get a `dir` and `getattr` equivalent for use with `obj`."""
- for identifier in _TYPE_IDENTIFIERS:
- predicate, versions = _REVIVED_TYPE_REGISTRY[identifier]
- if predicate(obj):
- return versions[0].attribute_extractor, versions[0].getter
- return dir, getattr
-
-
def serialize(obj):
"""Create a SavedUserObject from a checkpointable object."""
for identifier in _TYPE_IDENTIFIERS:
diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py
index 7eeb146..5b4f016 100644
--- a/tensorflow/python/saved_model/save.py
+++ b/tensorflow/python/saved_model/save.py
@@ -44,16 +44,18 @@
from tensorflow.python.saved_model import saved_object_graph_pb2
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import signature_serialization
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model import utils_impl
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import graph_view
+from tensorflow.python.training.checkpointable import object_identity
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util
+from tensorflow.python.training.saving import functional_saver
from tensorflow.python.util import compat
-from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export
-DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
_UNCOPIABLE_DTYPES = frozenset((dtypes.resource, dtypes.variant))
@@ -63,38 +65,88 @@
"_CapturedConstant", ["eager_tensor", "graph_tensor"])
+class _AugmentedGraphView(graph_view.ObjectGraphView):
+ """An extendable graph which also tracks functions attached to objects.
+
+ Extensions through `add_object` appear in the object graph and any checkpoints
+ generated from it, even if they are not dependencies of the node they were
+ attached to in the saving program. For example a `.signatures` attribute is
+ added to exported SavedModel root objects without modifying the root object
+ itself.
+
+ Also tracks functions attached to objects in the graph, through the caching
+ `list_functions` method. Enumerating functions only through this method
+ ensures that we get a consistent view of functions, even if object attributes
+ create new functions every time they are accessed.
+ """
+
+ def __init__(self, root):
+ super(_AugmentedGraphView, self).__init__(root)
+ # Object -> (name -> dep)
+ self._extra_dependencies = object_identity.ObjectIdentityDictionary()
+ self._functions = object_identity.ObjectIdentityDictionary()
+
+ def add_object(self, parent_node, name_in_parent, subgraph_root):
+ """Attach an object to `parent_node`, overriding any existing dependency."""
+ self._extra_dependencies.setdefault(
+ parent_node, {})[name_in_parent] = subgraph_root
+
+ def list_dependencies(self, obj):
+ """Overrides a parent method to include `add_object` objects."""
+ extra_dependencies = self._extra_dependencies.get(obj, {})
+ used_names = set()
+ for name, dep in super(_AugmentedGraphView, self).list_dependencies(obj):
+ used_names.add(name)
+ if name in extra_dependencies:
+ yield base.CheckpointableReference(name, extra_dependencies[name])
+ else:
+ yield base.CheckpointableReference(name, dep)
+ for name, dep in extra_dependencies.items():
+ if name in used_names:
+ continue
+ yield base.CheckpointableReference(name, dep)
+
+ def list_functions(self, obj):
+ obj_functions = self._functions.get(obj, None)
+ if obj_functions is None:
+ obj_functions = obj._list_functions_for_serialization() # pylint: disable=protected-access
+ self._functions[obj] = obj_functions
+ return obj_functions
+
+
class _SaveableView(object):
- """Provides a stable view over a checkpointable root.
+ """Provides a frozen view over a checkpointable root.
This class helps creating a single stable view over an object to save. The
saving code should access properties and functions via this class and not via
the original object as there are cases where an object construct their
checkpointable attributes and functions dynamically per call and will yield
different objects if invoked more than once.
+
+ Changes to the graph, for example adding objects, must happen in
+ `checkpoint_view` (an `_AugmentedGraphView`) before the `_SaveableView` is
+ constructed. Changes after the `_SaveableView` has been constructed will be
+ ignored.
"""
- def __init__(self, root):
- checkpointable_objects, node_ids, slot_variables = util.find_objects(root)
+ def __init__(self, checkpoint_view):
+ self.checkpoint_view = checkpoint_view
+ checkpointable_objects, node_ids, slot_variables = (
+ self.checkpoint_view.objects_ids_and_slot_variables())
self.nodes = checkpointable_objects
self.node_ids = node_ids
- self.captured_tensor_node_ids = util.ObjectIdentityDictionary()
+ self.captured_tensor_node_ids = object_identity.ObjectIdentityDictionary()
self.slot_variables = slot_variables
- self.functions = util.ObjectIdentityDictionary()
self.concrete_functions = []
# Also add `Function`s as nodes.
nodes_without_functions = list(self.nodes)
seen_function_names = set()
- for obj in nodes_without_functions:
- self.functions[obj] = self._list_functions(obj)
- for function in self.functions[obj].values():
+ for node in nodes_without_functions:
+ for function in checkpoint_view.list_functions(node).values():
if function not in self.node_ids:
self.node_ids[function] = len(self.nodes)
self.nodes.append(function)
- # Avoids recursing into functions to see if other functions are
- # assigned to attributes. This is sometimes true for concrete
- # functions but not helpful.
- self.functions[function] = {}
if isinstance(function, def_function.Function):
# Force listing the concrete functions for the side effects:
# - populate the cache for functions that have an input_signature
@@ -123,33 +175,16 @@
if isinstance(node, (def_function.Function, defun.ConcreteFunction,
_CapturedConstant)):
continue
- for child in node._checkpoint_dependencies: # pylint: disable=protected-access
+ for child in self.checkpoint_view.list_dependencies(node):
child_proto = object_proto.children.add()
child_proto.node_id = self.node_ids[child.ref]
child_proto.local_name = child.name
- for local_name, ref_function in self.functions[node].items():
+ for local_name, ref_function in (
+ self.checkpoint_view.list_functions(node).items()):
child_proto = object_proto.children.add()
child_proto.node_id = self.node_ids[ref_function]
child_proto.local_name = local_name
- def _list_functions(self, checkpointable_object):
- """Return a dict of `Function`s of a checkpointable."""
- functions = dict()
- attribute_extractor, attribute_getter = (
- revived_types.get_attribute_extractors(checkpointable_object))
- for attribute_name in attribute_extractor(checkpointable_object):
- try:
- attribute_value = attribute_getter(
- checkpointable_object, attribute_name, None)
- except Exception: # pylint: disable=broad-except
- # We really don't want to throw an exception just because some object's
- # attribute accessor is broken.
- attribute_value = None
- if isinstance(attribute_value, (def_function.Function,
- defun.ConcreteFunction)):
- functions[attribute_name] = attribute_value
- return functions
-
def map_resources(self):
"""Makes new resource handle ops corresponding to existing resource tensors.
@@ -171,7 +206,7 @@
assert not context.executing_eagerly()
# TODO(allenl): Handle MirroredVariables and other types of variables which
# may need special casing.
- object_map = util.ObjectIdentityDictionary()
+ object_map = object_identity.ObjectIdentityDictionary()
resource_map = {}
asset_info = _AssetInfo(
asset_defs=[],
@@ -210,108 +245,6 @@
return object_map, resource_map, asset_info
-def _get_signature(function):
- if (isinstance(function, (defun.Function, def_function.Function)) and
- function._input_signature is not None): # pylint: disable=protected-access
- function = function.get_concrete_function()
- if not isinstance(function, defun.ConcreteFunction):
- return None
- return function
-
-
-def _valid_signature(concrete_function):
- """Returns whether concrete function can be converted to a signature."""
- if not concrete_function.outputs:
- # Functions without outputs don't make sense as signatures. We just don't
- # have any way to run an Operation with no outputs as a SignatureDef in the
- # 1.x style.
- return False
- try:
- _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
- except ValueError:
- return False
- return True
-
-
-def _find_function_to_export(saveable_view):
- """Function to export, None if no suitable function was found."""
- # If the user did not specify signatures, check the root object for a function
- # that can be made into a signature.
- functions = saveable_view.functions[saveable_view.root]
- signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
- if signature is not None:
- return signature
-
- # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
- # user decides to annotate more functions with tf.function and suddenly
- # serving that model way later in the process stops working.
- if len(functions) == 1:
- single_function = list(functions.values())[0]
- signature = _get_signature(single_function)
- if signature and _valid_signature(signature):
- return signature
- return None
-
-
-def _canonicalize_signatures(signatures):
- """Converts `signatures` into a dictionary of concrete functions."""
- if signatures is None:
- return {}
- if not isinstance(signatures, collections.Mapping):
- signatures = {
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
- concrete_signatures = {}
- for signature_key, function in signatures.items():
- signature_function = _get_signature(function)
- if signature_function is None:
- raise ValueError(
- ("Expected a TensorFlow function to generate a signature for, but "
- "got {}. Only `tf.functions` with an input signature or "
- "concrete functions can be used as a signature.").format(function))
- concrete_signatures[signature_key] = signature_function
- return concrete_signatures
-
-
-def _is_flat(sequence):
- sequence_flat = nest.flatten(sequence)
- try:
- nest.assert_same_structure(sequence_flat, sequence)
- return True
- except ValueError:
- return False
- except TypeError:
- return False
-
-
-def _normalize_outputs(outputs, function_name, signature_key):
- """Construct an output dictionary from unnormalized function outputs."""
- if isinstance(outputs, collections.Mapping):
- for key, value in outputs.items():
- if not isinstance(value, ops.Tensor):
- raise ValueError(
- ("Got a dictionary containing non-Tensor value {} for key {} "
- "in the output of the function {} used to generate a SavedModel "
- "signature. Dictionaries outputs for functions used as signatures "
- "should have one Tensor output per string key.")
- .format(value, key, compat.as_str_any(function_name)))
- return outputs
- else:
- original_outputs = outputs
- if not isinstance(outputs, collections.Sequence):
- outputs = [outputs]
- if not _is_flat(outputs):
- raise ValueError(
- ("Got non-flat outputs '{}' from '{}' for SavedModel "
- "signature '{}'. Signatures have one Tensor per output, so "
- "to have predictable names Python functions used to generate "
- "these signatures should avoid outputting Tensors in nested "
- "structures.")
- .format(original_outputs, function_name, signature_key))
- return {("output_{}".format(output_index)): output
- for output_index, output
- in enumerate(outputs)}
-
-
def _tensor_dict_to_tensorinfo(tensor_dict):
return {key: utils_impl.build_tensor_info_internal(value)
for key, value in tensor_dict.items()}
@@ -441,8 +374,8 @@
Args:
signature_functions: A dictionary mapping string keys to concrete TensorFlow
- functions (e.g. from `_canonicalize_signatures`) which will be used to
- generate SignatureDefs.
+ functions (e.g. from `signature_serialization.canonicalize_signatures`)
+ which will be used to generate SignatureDefs.
resource_map: A dictionary mapping from resource tensors in the eager
context to resource tensors in the Graph being exported. This dictionary
is used to re-bind resources captured by functions to tensors which will
@@ -473,10 +406,8 @@
mapped_inputs, exterior_argument_placeholders = (
_map_function_arguments_to_created_inputs(
argument_inputs, signature_key, function.name))
- outputs = _normalize_outputs(
- _call_function_with_mapped_captures(
- function, mapped_inputs, resource_map),
- function.name, signature_key)
+ outputs = _call_function_with_mapped_captures(
+ function, mapped_inputs, resource_map)
signatures[signature_key] = signature_def_utils.build_signature_def(
_tensor_dict_to_tensorinfo(exterior_argument_placeholders),
_tensor_dict_to_tensorinfo(outputs),
@@ -541,8 +472,7 @@
resource_map[original_variable.handle] = asset_variable.handle
-def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions,
- object_saver):
+def _fill_meta_graph_def(meta_graph_def, saveable_view, signature_functions):
"""Generates a MetaGraph which calls `signature_functions`.
Args:
@@ -550,7 +480,6 @@
saveable_view: The _SaveableView being exported.
signature_functions: A dictionary mapping signature keys to concrete
functions containing signatures to add to the MetaGraph.
- object_saver: A CheckpointableSaver to add to the MetaGraph.
Returns:
An _AssetInfo, which contains information to help creating the SavedModel.
@@ -590,7 +519,9 @@
# gathering from the eager context so Optimizers save the right set of
# variables, but want any operations associated with the save/restore to be in
# the exported graph (thus the `to_graph` argument).
- saver = object_saver.freeze(object_map=object_map, to_graph=exported_graph)
+ saver = functional_saver.Saver(
+ saveable_view.checkpoint_view.frozen_saveable_objects(
+ object_map=object_map, to_graph=exported_graph))
with exported_graph.as_default():
signatures = _generate_signatures(signature_functions, resource_map)
@@ -712,6 +643,10 @@
which case outputs will be numbered, or a dictionary mapping string keys to
`Tensor`, in which case the keys will be used to name outputs.
+ Signatures are available in objects returned by `tf.saved_model.load` as a
+ `.signatures` attribute. This is a reserved attribute: `tf.saved_model.save`
+ on an object with a custom `.signatures` attribute will raise an exception.
+
Since `tf.keras.Model` objects are also Checkpointable, this function can be
used to export Keras models. For example, exporting with a signature
specified:
@@ -836,26 +771,33 @@
raise ValueError(
"Expected a Checkpointable object for export, got {}.".format(obj))
- # Use _SaveableView to provide a stable listing of properties and functions.
+ checkpoint_graph_view = _AugmentedGraphView(obj)
+ if signatures is None:
+ signatures = signature_serialization.find_function_to_export(
+ checkpoint_graph_view)
+
+ signatures = signature_serialization.canonicalize_signatures(signatures)
+ signature_map = signature_serialization.create_signature_map(
+ signatures, checkpoint_graph_view)
+ checkpoint_graph_view.add_object(
+ parent_node=checkpoint_graph_view.root,
+ name_in_parent=signature_serialization.SIGNATURE_ATTRIBUTE_NAME,
+ subgraph_root=signature_map)
+
+ # Use _SaveableView to provide a frozen listing of properties and functions.
# Note we run this twice since, while constructing the view the first time
# there can be side effects of creating variables.
- _ = _SaveableView(obj)
- saveable_view = _SaveableView(obj)
-
- if signatures is None:
- signatures = _find_function_to_export(saveable_view)
-
- signatures = _canonicalize_signatures(signatures)
+ _ = _SaveableView(checkpoint_graph_view)
+ saveable_view = _SaveableView(checkpoint_graph_view)
# TODO(allenl): Factor out some subset of SavedModelBuilder which is 2.x
# compatible (no sessions) and share it with this export API rather than
# making a SavedModel proto and writing it directly.
saved_model = saved_model_pb2.SavedModel()
meta_graph_def = saved_model.meta_graphs.add()
- # TODO(andresp): Should this be using saveable_view?
- object_saver = util.CheckpointableSaver(obj)
+ object_saver = util.CheckpointableSaver(checkpoint_graph_view)
asset_info, exported_graph = _fill_meta_graph_def(
- meta_graph_def, saveable_view, signatures, object_saver)
+ meta_graph_def, saveable_view, signatures)
saved_model.saved_model_schema_version = (
constants.SAVED_MODEL_SCHEMA_VERSION)
# So far we've just been generating protocol buffers with no I/O. Now we write
diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py
index cbca51a..e1b2f5d 100644
--- a/tensorflow/python/saved_model/save_test.py
+++ b/tensorflow/python/saved_model/save_test.py
@@ -303,6 +303,14 @@
self.assertNotIn("T", complex_node.attr)
self.assertNotIn("Tout", complex_node.attr)
+ def test_signature_attribute_reserved(self):
+ root = util.Checkpoint(signatures=variables.Variable(1.))
+ save_dir = os.path.join(self.get_temp_dir(), "saved_model")
+ with self.assertRaisesRegexp(ValueError, "del obj.signatures"):
+ save.save(root, save_dir)
+ del root.signatures
+ save.save(root, save_dir)
+
class AssetTests(test.TestCase):
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 0efe176..8047d0d 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -136,6 +136,10 @@
################################################################################
# Train/Eval API constants.
# Not exported while export_all_saved_models is experimental.
+DEFAULT_TRAIN_SIGNATURE_DEF_KEY = "train"
+# TODO(b/123998850): Change default signature key to "test" after making sure
+# that TFMA use cases won't break.
+DEFAULT_EVAL_SIGNATURE_DEF_KEY = "eval"
SUPERVISED_TRAIN_METHOD_NAME = "tensorflow/supervised/training"
diff --git a/tensorflow/python/saved_model/signature_serialization.py b/tensorflow/python/saved_model/signature_serialization.py
new file mode 100644
index 0000000..19f15f188
--- /dev/null
+++ b/tensorflow/python/saved_model/signature_serialization.py
@@ -0,0 +1,244 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Helpers for working with signatures in tf.saved_model.save."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.eager import def_function
+from tensorflow.python.eager import function as defun
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.saved_model import revived_types
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.util import compat
+from tensorflow.python.util import nest
+
+
+DEFAULT_SIGNATURE_ATTR = "_default_save_signature"
+SIGNATURE_ATTRIBUTE_NAME = "signatures"
+
+
+def _get_signature(function):
+ if (isinstance(function, (defun.Function, def_function.Function)) and
+ function._input_signature is not None): # pylint: disable=protected-access
+ function = function.get_concrete_function()
+ if not isinstance(function, defun.ConcreteFunction):
+ return None
+ return function
+
+
+def _valid_signature(concrete_function):
+ """Returns whether concrete function can be converted to a signature."""
+ if not concrete_function.outputs:
+ # Functions without outputs don't make sense as signatures. We just don't
+ # have any way to run an Operation with no outputs as a SignatureDef in the
+ # 1.x style.
+ return False
+ try:
+ _normalize_outputs(concrete_function.structured_outputs, "unused", "unused")
+ except ValueError:
+ return False
+ return True
+
+
+def find_function_to_export(saveable_view):
+ """Function to export, None if no suitable function was found."""
+ # If the user did not specify signatures, check the root object for a function
+ # that can be made into a signature.
+ functions = saveable_view.list_functions(saveable_view.root)
+ signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
+ if signature is not None:
+ return signature
+
+ # TODO(andresp): Discuss removing this behaviour. It can lead to WTFs when a
+ # user decides to annotate more functions with tf.function and suddenly
+ # serving that model way later in the process stops working.
+ possible_signatures = []
+ for function in functions.values():
+ concrete = _get_signature(function)
+ if concrete is not None and _valid_signature(concrete):
+ possible_signatures.append(concrete)
+ if len(possible_signatures) == 1:
+ single_function = possible_signatures[0]
+ signature = _get_signature(single_function)
+ if signature and _valid_signature(signature):
+ return signature
+ return None
+
+
+def canonicalize_signatures(signatures):
+ """Converts `signatures` into a dictionary of concrete functions."""
+ if signatures is None:
+ return {}
+ if not isinstance(signatures, collections.Mapping):
+ signatures = {
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signatures}
+ concrete_signatures = {}
+ for signature_key, function in signatures.items():
+ signature_function = _get_signature(function)
+ if signature_function is None:
+ raise ValueError(
+ ("Expected a TensorFlow function to generate a signature for, but "
+ "got {}. Only `tf.functions` with an input signature or "
+ "concrete functions can be used as a signature.").format(function))
+
+ # Re-wrap the function so that it only takes keyword arguments and it
+ # returns a dictionary of Tensors. This matches the format of 1.x-style
+ # signatures.
+ # pylint: disable=cell-var-from-loop
+ @def_function.function
+ def signature_wrapper(**kwargs):
+ structured_outputs = signature_function(**kwargs)
+ return _normalize_outputs(
+ structured_outputs, signature_function.name, signature_key)
+ # TODO(b/123902469): Use ConcreteFunction.structured_inputs once their names
+ # always match keyword arguments.
+ tensor_spec_signature = {}
+ for keyword, tensor in zip(
+ signature_function._arg_keywords, # pylint: disable=protected-access
+ signature_function.inputs):
+ keyword = compat.as_str(keyword)
+ tensor_spec_signature[keyword] = tensor_spec.TensorSpec.from_tensor(
+ tensor, name=keyword)
+ concrete_signatures[signature_key] = (
+ signature_wrapper.get_concrete_function(**tensor_spec_signature))
+ # pylint: enable=cell-var-from-loop
+ return concrete_signatures
+
+
+def _is_flat(sequence):
+ sequence_flat = nest.flatten(sequence)
+ try:
+ nest.assert_same_structure(sequence_flat, sequence)
+ return True
+ except ValueError:
+ return False
+ except TypeError:
+ return False
+
+
+def _normalize_outputs(outputs, function_name, signature_key):
+ """Construct an output dictionary from unnormalized function outputs."""
+ if isinstance(outputs, collections.Mapping):
+ for key, value in outputs.items():
+ if not isinstance(value, ops.Tensor):
+ raise ValueError(
+ ("Got a dictionary containing non-Tensor value {} for key {} "
+ "in the output of the function {} used to generate a SavedModel "
+ "signature. Dictionaries outputs for functions used as signatures "
+ "should have one Tensor output per string key.")
+ .format(value, key, compat.as_str_any(function_name)))
+ return outputs
+ else:
+ original_outputs = outputs
+ if not isinstance(outputs, collections.Sequence):
+ outputs = [outputs]
+ if not _is_flat(outputs):
+ raise ValueError(
+ ("Got non-flat outputs '{}' from '{}' for SavedModel "
+ "signature '{}'. Signatures have one Tensor per output, so "
+ "to have predictable names Python functions used to generate "
+ "these signatures should avoid outputting Tensors in nested "
+ "structures.")
+ .format(original_outputs, function_name, signature_key))
+ return {("output_{}".format(output_index)): output
+ for output_index, output
+ in enumerate(outputs)}
+
+
+# _SignatureMap is immutable to ensure that users do not expect changes to be
+# reflected in the SavedModel. Using public APIs, tf.saved_model.load() is the
+# only way to create a _SignatureMap and there is no way to modify it. So we can
+# safely ignore/overwrite ".signatures" attributes attached to objects being
+# saved if they contain a _SignatureMap. A ".signatures" attribute containing
+# any other type (e.g. a regular dict) will raise an exception asking the user
+# to first "del obj.signatures" if they want it overwritten.
+class _SignatureMap(collections.Mapping, base.Checkpointable):
+ """A collection of SavedModel signatures."""
+
+ def __init__(self):
+ self._signatures = {}
+
+ def _add_signature(self, name, concrete_function):
+ """Adds a signature to the _SignatureMap."""
+ # Ideally this object would be immutable, but restore is streaming so we do
+ # need a private API for adding new signatures to an existing object.
+ self._signatures[name] = concrete_function
+
+ def __getitem__(self, key):
+ return self._signatures[key]
+
+ def __iter__(self):
+ return iter(self._signatures)
+
+ def __len__(self):
+ return len(self._signatures)
+
+ def __repr__(self):
+ return "_SignatureMap({})".format(self._signatures)
+
+ def _list_functions_for_serialization(self):
+ return {
+ key: value for key, value in self.items()
+ if isinstance(value, (def_function.Function, defun.ConcreteFunction))
+ }
+
+
+revived_types.register_revived_type(
+ "signature_map",
+ lambda obj: isinstance(obj, _SignatureMap),
+ versions=[revived_types.VersionedTypeRegistration(
+ # Standard dependencies are enough to reconstruct the checkpointable
+ # items in dictionaries, so we don't need to save any extra information.
+ object_factory=lambda proto: _SignatureMap(),
+ version=1,
+ min_producer_version=1,
+ min_consumer_version=1,
+ setter=_SignatureMap._add_signature # pylint: disable=protected-access
+ )])
+
+
+def create_signature_map(signatures, saveable_view):
+ """Performs sanity checks and creates an object containing `signatures`."""
+ for name, dep in saveable_view.list_dependencies(
+ saveable_view.root):
+ if name == SIGNATURE_ATTRIBUTE_NAME:
+ if not isinstance(dep, _SignatureMap):
+ raise ValueError(
+ ("Exporting an object {} which has an attribute named "
+ "'{signatures}'. This is a reserved attribute used to store "
+ "SavedModel signatures in objects which come from "
+ "`tf.saved_model.load`. Delete this attribute "
+ "(e.g. 'del obj.{signatures}') before saving if this shadowing is "
+ "acceptable.").format(
+ saveable_view.root,
+ signatures=SIGNATURE_ATTRIBUTE_NAME))
+ break
+ signature_map = _SignatureMap()
+ for name, func in signatures.items():
+ # This true of any signature that came from canonicalize_signatures. Here as
+ # a sanity check on saving; crashing on load (e.g. in _add_signature) would
+ # be more problematic in case future export changes violated these
+ # assertions.
+ assert isinstance(func, defun.ConcreteFunction)
+ assert isinstance(func.structured_outputs, collections.Mapping)
+ assert 0 == func._num_positional_args # pylint: disable=protected-access
+ signature_map._add_signature(name, func) # pylint: disable=protected-access
+ return signature_map
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index f1a911e..b8ad31e 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -14,23 +14,20 @@
py_library(
name = "tools_pip",
deps = [
- ":freeze_graph",
- ":import_pb_to_tensorboard",
- ":inspect_checkpoint",
- ":optimize_for_inference",
- ":print_selective_registration_header",
- ":saved_model_cli",
+ ":freeze_graph_lib",
+ ":import_pb_to_tensorboard_lib",
+ ":inspect_checkpoint_lib",
+ ":optimize_for_inference_lib",
+ ":print_selective_registration_header_lib",
+ ":saved_model_cli_lib",
":saved_model_utils",
- ":strip_unused",
+ ":strip_unused_lib",
# The following py_library are needed because
# py_binary may not depend on them when --define=no_tensorflow_py_deps=true
# is specified. See https://github.com/tensorflow/tensorflow/issues/22390
- ":freeze_graph_lib",
- ":optimize_for_inference_lib",
":selective_registration_header_lib",
- ":strip_unused_lib",
# Include the TF upgrade script to users can run it directly after install TF
- "//tensorflow/tools/compatibility:tf_upgrade_v2",
+ "//tensorflow/tools/compatibility:tf_upgrade_v2_lib",
],
)
@@ -86,6 +83,13 @@
name = "import_pb_to_tensorboard",
srcs = ["import_pb_to_tensorboard.py"],
srcs_version = "PY2AND3",
+ deps = [":import_pb_to_tensorboard_lib"],
+)
+
+py_library(
+ name = "import_pb_to_tensorboard_lib",
+ srcs = ["import_pb_to_tensorboard.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/core:protos_all_py",
"//tensorflow/python",
@@ -103,7 +107,7 @@
srcs = ["freeze_graph_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":freeze_graph",
+ ":freeze_graph_lib",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client",
"//tensorflow/python:client_testlib",
@@ -120,6 +124,13 @@
name = "inspect_checkpoint",
srcs = ["inspect_checkpoint.py"],
srcs_version = "PY2AND3",
+ deps = [":inspect_checkpoint_lib"],
+)
+
+py_library(
+ name = "inspect_checkpoint_lib",
+ srcs = ["inspect_checkpoint.py"],
+ srcs_version = "PY2AND3",
deps = [
"//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:platform",
@@ -240,6 +251,14 @@
srcs = ["print_selective_registration_header.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
+ deps = [":print_selective_registration_header_lib"],
+)
+
+py_library(
+ name = "print_selective_registration_header_lib",
+ srcs = ["print_selective_registration_header.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
deps = [
":selective_registration_header_lib",
"//tensorflow/python:platform",
@@ -261,6 +280,13 @@
name = "saved_model_cli",
srcs = ["saved_model_cli.py"],
srcs_version = "PY2AND3",
+ deps = [":saved_model_cli_lib"],
+)
+
+py_library(
+ name = "saved_model_cli_lib",
+ srcs = ["saved_model_cli.py"],
+ srcs_version = "PY2AND3",
deps = [
":saved_model_utils",
"//tensorflow/python",
@@ -280,7 +306,7 @@
"no-internal-py3",
],
deps = [
- ":saved_model_cli",
+ ":saved_model_cli_lib",
"//tensorflow/core:protos_all_py",
],
)
diff --git a/tensorflow/python/training/checkpointable/BUILD b/tensorflow/python/training/checkpointable/BUILD
index a394627..e1f58a9 100644
--- a/tensorflow/python/training/checkpointable/BUILD
+++ b/tensorflow/python/training/checkpointable/BUILD
@@ -29,6 +29,7 @@
"//tensorflow/python:util",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/saving:saveable_object",
+ "@six_archive//:six",
],
)
@@ -95,23 +96,48 @@
)
py_library(
+ name = "object_identity",
+ srcs = ["object_identity.py"],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "graph_view",
+ srcs = ["graph_view.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base",
+ ":object_identity",
+ ":tracking",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/training/saving:saveable_object",
+ "//tensorflow/python/training/saving:saveable_object_util",
+ ],
+)
+
+py_library(
name = "util",
srcs = ["util.py"],
srcs_version = "PY2AND3",
deps = [
":base",
":data_structures",
+ ":graph_view",
+ ":object_identity",
":tracking",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
- "//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:io_ops_gen",
+ "//tensorflow/python:lib",
"//tensorflow/python:pywrap_tensorflow",
"//tensorflow/python:saver",
"//tensorflow/python:session",
@@ -122,7 +148,6 @@
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:def_function",
"//tensorflow/python/training/saving:functional_saver",
- "//tensorflow/python/training/saving:saveable_object",
"//tensorflow/python/training/saving:saveable_object_util",
],
)
@@ -132,10 +157,12 @@
srcs = ["util_test.py"],
additional_deps = [
":base",
+ ":graph_view",
":tracking",
":util",
"@absl_py//absl/testing:parameterized",
"@six_archive//:six",
+ "//tensorflow/python/keras/optimizer_v2",
"//tensorflow/python:checkpoint_management",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
@@ -149,8 +176,8 @@
"//tensorflow/python:session",
"//tensorflow/python:state_ops",
"//tensorflow/python:template",
- "//tensorflow/python:training",
"//tensorflow/python:training_util",
+ "//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
@@ -158,6 +185,7 @@
"//tensorflow/python/eager:test",
"//tensorflow/python/keras:engine",
"//tensorflow/python/keras:layers",
+ "//tensorflow/python:variables",
],
tags = ["notsan"], # b/74395663
)
@@ -190,6 +218,7 @@
srcs = ["util_with_v1_optimizers_test.py"],
additional_deps = [
":base",
+ ":graph_view",
":tracking",
":util",
"@absl_py//absl/testing:parameterized",
diff --git a/tensorflow/python/training/checkpointable/base.py b/tensorflow/python/training/checkpointable/base.py
index 9fb251e..0d659ce 100644
--- a/tensorflow/python/training/checkpointable/base.py
+++ b/tensorflow/python/training/checkpointable/base.py
@@ -187,14 +187,14 @@
return control_flow_ops.no_op()
-class _CheckpointPosition(object):
- """Indicates a position within a `_Checkpoint`."""
+class CheckpointPosition(object):
+ """Indicates a position within a `_CheckpointRestoreCoordinator`."""
def __init__(self, checkpoint, proto_id):
"""Specify an object within a checkpoint.
Args:
- checkpoint: A _Checkpoint object.
+ checkpoint: A _CheckpointRestoreCoordinator object.
proto_id: The index of this object in CheckpointableObjectGraph.nodes.
"""
self._checkpoint = checkpoint
@@ -229,7 +229,7 @@
for deferred_slot_restoration in (
checkpoint.deferred_slot_restorations.pop(self._proto_id, ())):
checkpointable._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=_CheckpointPosition(
+ slot_variable_position=CheckpointPosition(
checkpoint=checkpoint,
proto_id=deferred_slot_restoration.slot_variable_id),
variable=deferred_slot_restoration.original_variable,
@@ -249,7 +249,7 @@
slot_name=slot_restoration.slot_name))
else:
optimizer_object._create_or_restore_slot_variable( # pylint: disable=protected-access
- slot_variable_position=_CheckpointPosition(
+ slot_variable_position=CheckpointPosition(
checkpoint=checkpoint,
proto_id=slot_restoration.slot_variable_id),
variable=checkpointable,
@@ -325,14 +325,15 @@
# the SaveableObject itself has been cached. If not, we'll make it, and
# either way we'll extract new ops from it (or if it has Python state to
# restore, we'll run that).
- if self._checkpoint.saveable_object_cache is None:
+ saveables_cache = self._checkpoint.graph_view.saveables_cache
+ if saveables_cache is None:
# No SaveableObject caching when executing eagerly.
saveable = None
else:
# If we've already created and cached a SaveableObject for this
# attribute, we can re-use it to avoid re-creating some ops when graph
# building.
- saveable_list = self._checkpoint.saveable_object_cache.get(
+ saveable_list = saveables_cache.get(
self.checkpointable, {}).get(serialized_tensor.name, (None,))
if len(saveable_list) == 1:
# Almost every attribute will have exactly one SaveableObject.
@@ -347,7 +348,7 @@
# the SaveableObject.
if serialized_tensor.checkpoint_key not in saveable.name:
saveable = None
- del self._checkpoint.saveable_object_cache[self.checkpointable]
+ del saveables_cache[self.checkpointable]
break
if saveable is None:
# If there was no cached SaveableObject, we should check if the Python
@@ -366,8 +367,8 @@
saveable = saveable_factory(name=serialized_tensor.checkpoint_key)
else:
saveable = saveable_factory
- if self._checkpoint.saveable_object_cache is not None:
- self._checkpoint.saveable_object_cache.setdefault(
+ if saveables_cache is not None:
+ saveables_cache.setdefault(
self.checkpointable, {})[serialized_tensor.name] = [saveable]
if isinstance(saveable, PythonStateSaveable):
python_saveables.append(saveable)
@@ -491,7 +492,7 @@
# Maps names -> Checkpointable objects
self._unconditional_dependency_names = {}
# Restorations for other Checkpointable objects on which this object may
- # eventually depend. Maps local name -> _CheckpointPosition list. Optimizers
+ # eventually depend. Maps local name -> CheckpointPosition list. Optimizers
# tack on conditional dependencies, and so need separate management of
# deferred dependencies too.
self._unconditional_deferred_dependencies = {}
@@ -545,7 +546,7 @@
management of deferred dependencies too).
Returns:
- A dictionary mapping from local name to a list of _CheckpointPosition
+ A dictionary mapping from local name to a list of CheckpointPosition
objects.
"""
return self._unconditional_deferred_dependencies
@@ -791,7 +792,7 @@
else:
restore_ops = ()
for child in checkpoint_position.object_proto.children:
- child_position = _CheckpointPosition(
+ child_position = CheckpointPosition(
checkpoint=checkpoint,
proto_id=child.node_id)
local_object = self._lookup_dependency(child.local_name)
@@ -861,3 +862,16 @@
return {OBJECT_CONFIG_JSON_KEY: functools.partial(
PythonStringStateSaveable,
state_callback=_state_callback)}
+
+ def _list_functions_for_serialization(self):
+ """Lists the functions of this checkpointable to serialize.
+
+ Internal sub-classes can override this with specific logic. E.g.
+ `AutoCheckpointable` provides an implementation that returns the `attr`
+ that return functions.
+
+ Returns:
+ A dictionary mapping attribute names to `Function` or
+ `ConcreteFunction`.
+ """
+ return dict()
diff --git a/tensorflow/python/training/checkpointable/data_structures.py b/tensorflow/python/training/checkpointable/data_structures.py
index c86846f..ae3ab3f 100644
--- a/tensorflow/python/training/checkpointable/data_structures.py
+++ b/tensorflow/python/training/checkpointable/data_structures.py
@@ -24,6 +24,8 @@
import six
+from tensorflow.python.eager import def_function
+from tensorflow.python.eager import function as defun
from tensorflow.python.ops import variables
from tensorflow.python.saved_model import revived_types
from tensorflow.python.training.checkpointable import base
@@ -525,6 +527,12 @@
def __repr__(self):
return "ListWrapper(%s)" % (repr(self._storage),)
+ def _list_functions_for_serialization(self):
+ return {
+ str(key): value for key, value in enumerate(self)
+ if _is_function(value)
+ }
+
class Mapping(CheckpointableDataStructure, collections.Mapping):
"""An append-only checkpointable mapping data structure with string keys.
@@ -793,6 +801,16 @@
for key, value in dict(*args, **kwargs).items():
self[key] = value
+ def _list_functions_for_serialization(self):
+ return {
+ key: value for key, value in self.items()
+ if _is_function(value)
+ }
+
+
+def _is_function(x):
+ return isinstance(x, (def_function.Function, defun.ConcreteFunction))
+
revived_types.register_revived_type(
"checkpointable_dict_wrapper",
lambda obj: isinstance(obj, _DictWrapper),
@@ -803,9 +821,7 @@
version=1,
min_producer_version=1,
min_consumer_version=1,
- setter=operator.setitem,
- getter=_DictWrapper.get,
- attribute_extractor=lambda obj: obj.keys())])
+ setter=operator.setitem)])
def _set_list_item(list_object, index_string, value):
@@ -815,13 +831,6 @@
list_object[item_index] = value
-def _list_getter(obj, item, default=None):
- index = int(item)
- if index < len(obj):
- return obj[index]
- return default
-
-
revived_types.register_revived_type(
"checkpointable_list_wrapper",
lambda obj: isinstance(obj, _ListWrapper),
@@ -830,6 +839,4 @@
version=1,
min_producer_version=1,
min_consumer_version=1,
- setter=_set_list_item,
- getter=_list_getter,
- attribute_extractor=lambda obj: [str(i) for i in range(len(obj))])])
+ setter=_set_list_item)])
diff --git a/tensorflow/python/training/checkpointable/graph_view.py b/tensorflow/python/training/checkpointable/graph_view.py
new file mode 100644
index 0000000..46c6289
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/graph_view.py
@@ -0,0 +1,431 @@
+"""Manages a graph of Checkpointable objects."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import weakref
+
+from tensorflow.core.protobuf import checkpointable_object_graph_pb2
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import object_identity
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.saving import saveable_object as saveable_object_lib
+from tensorflow.python.training.saving import saveable_object_util
+
+
+_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names.
+
+# Keyword for identifying that the next bit of a checkpoint variable name is a
+# slot name. Checkpoint names for slot variables look like:
+#
+# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
+#
+# Where <path to variable> is a full path from the checkpoint root to the
+# variable being slotted for.
+_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
+# Keyword for separating the path to an object from the name of an
+# attribute in checkpoint names. Used like:
+# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
+_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
+
+
+def _escape_local_name(name):
+ # We need to support slashes in local names for compatibility, since this
+ # naming scheme is being patched in to things like Layer.add_variable where
+ # slashes were previously accepted. We also want to use slashes to indicate
+ # edges traversed to reach the variable, so we escape forward slashes in
+ # names.
+ return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
+ .replace(r"/", _ESCAPE_CHAR + "S"))
+
+
+def _object_prefix_from_path(path_to_root):
+ return "/".join(
+ (_escape_local_name(checkpointable.name)
+ for checkpointable in path_to_root))
+
+
+def _slot_variable_naming_for_optimizer(optimizer_path):
+ """Make a function for naming slot variables in an optimizer."""
+ # Name slot variables:
+ #
+ # <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>
+ #
+ # where <variable name> is exactly the checkpoint name used for the original
+ # variable, including the path from the checkpoint root and the local name in
+ # the object which owns it. Note that we only save slot variables if the
+ # variable it's slotting for is also being saved.
+
+ optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path)
+
+ def _name_slot_variable(variable_path, slot_name):
+ """With an optimizer specified, name a slot variable."""
+ return (variable_path
+ + optimizer_identifier
+ + _escape_local_name(slot_name))
+
+ return _name_slot_variable
+
+
+def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
+ """Gather and name slot variables."""
+ non_slot_objects = list(checkpointable_objects)
+ slot_variables = object_identity.ObjectIdentityDictionary()
+ for checkpointable in non_slot_objects:
+ if (isinstance(checkpointable, optimizer_v1.Optimizer)
+ # TODO(b/110718070): Fix Keras imports.
+ or hasattr(checkpointable, "_create_or_restore_slot_variable")):
+ naming_scheme = _slot_variable_naming_for_optimizer(
+ optimizer_path=object_names[checkpointable])
+ slot_names = checkpointable.get_slot_names()
+ for slot_name in slot_names:
+ for original_variable_node_id, original_variable in enumerate(
+ non_slot_objects):
+ try:
+ slot_variable = checkpointable.get_slot(
+ original_variable, slot_name)
+ except (AttributeError, KeyError):
+ slot_variable = None
+ if slot_variable is None:
+ continue
+ slot_variable._maybe_initialize_checkpointable() # pylint: disable=protected-access
+ if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access
+ # TODO(allenl): Gather dependencies of slot variables.
+ raise NotImplementedError(
+ "Currently only variables with no dependencies can be saved as "
+ "slot variables. File a feature request if this limitation "
+ "bothers you.")
+ if slot_variable in node_ids:
+ raise NotImplementedError(
+ "A slot variable was re-used as a dependency of a "
+ "Checkpointable object. This is not currently allowed. File a "
+ "feature request if this limitation bothers you.")
+ checkpoint_name = naming_scheme(
+ variable_path=object_names[original_variable],
+ slot_name=slot_name)
+ object_names[slot_variable] = checkpoint_name
+ slot_variable_node_id = len(checkpointable_objects)
+ node_ids[slot_variable] = slot_variable_node_id
+ checkpointable_objects.append(slot_variable)
+ slot_variable_proto = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph
+ .CheckpointableObject.SlotVariableReference(
+ slot_name=slot_name,
+ original_variable_node_id=original_variable_node_id,
+ slot_variable_node_id=slot_variable_node_id))
+ slot_variables.setdefault(checkpointable, []).append(
+ slot_variable_proto)
+ return slot_variables
+
+
+class ObjectGraphView(object):
+ """Gathers and serializes an object graph."""
+
+ def __init__(self, root, saveables_cache=None):
+ """Configure the graph view.
+
+ Args:
+ root: A `Checkpointable` object whose variables (including the variables
+ of dependencies, recursively) should be saved. May be a weak reference.
+ saveables_cache: A dictionary mapping `Checkpointable` objects ->
+ attribute names -> SaveableObjects, used to avoid re-creating
+ SaveableObjects when graph building.
+ """
+ self._root_ref = root
+ self._saveables_cache = saveables_cache
+
+ def list_dependencies(self, obj):
+ # pylint: disable=protected-access
+ obj._maybe_initialize_checkpointable()
+ return obj._checkpoint_dependencies
+ # pylint: enable=protected-access
+
+ @property
+ def saveables_cache(self):
+ """Maps Checkpointable objects -> attribute names -> list(SaveableObjects).
+
+ Used to avoid re-creating SaveableObjects when graph building. None when
+ executing eagerly.
+
+ Returns:
+ The cache (an object-identity dictionary), or None if caching is disabled.
+ """
+ return self._saveables_cache
+
+ @property
+ def root(self):
+ if isinstance(self._root_ref, weakref.ref):
+ derefed = self._root_ref()
+ assert derefed is not None
+ return derefed
+ else:
+ return self._root_ref
+
+ def _breadth_first_traversal(self):
+ """Find shortest paths to all dependencies of self.root."""
+ bfs_sorted = []
+ to_visit = collections.deque([self.root])
+ path_to_root = object_identity.ObjectIdentityDictionary()
+ path_to_root[self.root] = ()
+ while to_visit:
+ current_checkpointable = to_visit.popleft()
+ if isinstance(current_checkpointable, tracking.NotCheckpointable):
+ raise NotImplementedError(
+ ("The object %s does not support object-based saving. File a "
+ "feature request if this limitation bothers you. In the meantime, "
+ "you can remove the dependency on this object and save everything "
+ "else.")
+ % (current_checkpointable,))
+ bfs_sorted.append(current_checkpointable)
+ for name, dependency in self.list_dependencies(current_checkpointable):
+ if dependency not in path_to_root:
+ path_to_root[dependency] = (
+ path_to_root[current_checkpointable] + (
+ base.CheckpointableReference(name, dependency),))
+ to_visit.append(dependency)
+ return bfs_sorted, path_to_root
+
+ def _add_attributes_to_object_graph(
+ self, checkpointable_objects, object_graph_proto, node_ids, object_names,
+ object_map):
+ """Create SaveableObjects and corresponding SerializedTensor protos."""
+ named_saveable_objects = []
+ if self._saveables_cache is None:
+ # No SaveableObject caching. Either we're executing eagerly, or building a
+ # static save which is specialized to the current Python state.
+ feed_additions = None
+ else:
+ # If we are caching SaveableObjects, we need to build up a feed_dict with
+ # functions computing volatile Python state to be saved with the
+ # checkpoint.
+ feed_additions = {}
+ for checkpoint_id, (checkpointable, object_proto) in enumerate(
+ zip(checkpointable_objects, object_graph_proto.nodes)):
+ assert node_ids[checkpointable] == checkpoint_id
+ object_name = object_names[checkpointable]
+ if object_map is None:
+ object_to_save = checkpointable
+ else:
+ object_to_save = object_map.get(checkpointable, checkpointable)
+ if self._saveables_cache is not None:
+ cached_attributes = self._saveables_cache.setdefault(object_to_save, {})
+ else:
+ cached_attributes = None
+
+ for name, saveable_factory in (
+ object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
+ attribute = object_proto.attributes.add()
+ attribute.name = name
+ attribute.checkpoint_key = "%s/%s/%s" % (
+ object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
+ if cached_attributes is None:
+ saveables = None
+ else:
+ saveables = cached_attributes.get(name, None)
+ if saveables is not None:
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ # The checkpoint key for this SaveableObject is different. We
+ # need to re-create it.
+ saveables = None
+ del cached_attributes[name]
+ break
+ if saveables is None:
+ if callable(saveable_factory):
+ maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
+ else:
+ maybe_saveable = saveable_factory
+ if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
+ saveables = (maybe_saveable,)
+ else:
+ # Figure out the name-based Saver's name for this variable. If it's
+ # already a SaveableObject we'd just get the checkpoint key back, so
+ # we leave full_name blank.
+ saver_dict = saveable_object_util.op_list_to_dict(
+ [maybe_saveable], convert_variable_to_tensor=False)
+ full_name, = saver_dict.keys()
+ saveables = tuple(saveable_object_util.saveable_objects_for_op(
+ op=maybe_saveable, name=attribute.checkpoint_key))
+ for saveable in saveables:
+ saveable.full_name = full_name
+ for saveable in saveables:
+ if attribute.checkpoint_key not in saveable.name:
+ raise AssertionError(
+ ("The object %s produced a SaveableObject with name '%s' for "
+ "attribute '%s'. Expected a name containing '%s'.")
+ % (checkpointable, name, saveable.name,
+ attribute.checkpoint_key))
+ if cached_attributes is not None:
+ cached_attributes[name] = saveables
+
+ optional_restore = None
+ for saveable in saveables:
+ if optional_restore is None:
+ optional_restore = saveable.optional_restore
+ else:
+ optional_restore = optional_restore and saveable.optional_restore
+
+ if hasattr(saveable, "full_name"):
+ attribute.full_name = saveable.full_name
+ if isinstance(saveable, base.PythonStateSaveable):
+ if feed_additions is None:
+ assert self._saveables_cache is None
+ # If we're not caching saveables, then we're either executing
+ # eagerly or building a static save/restore (e.g. for a
+ # SavedModel). In either case, we should embed the current Python
+ # state in the graph rather than relying on a feed dict.
+ saveable = saveable.freeze()
+ else:
+ saveable_feed_dict = saveable.feed_dict_additions()
+ for new_feed_key in saveable_feed_dict.keys():
+ if new_feed_key in feed_additions:
+ raise AssertionError(
+ ("The object %s tried to feed a value for the Tensor %s "
+ "when saving, but another object is already feeding a "
+ "value.")
+ % (checkpointable, new_feed_key))
+ feed_additions.update(saveable_feed_dict)
+ named_saveable_objects.append(saveable)
+ if optional_restore is None:
+ optional_restore = False
+ attribute.optional_restore = optional_restore
+
+ return named_saveable_objects, feed_additions
+
+ def _fill_object_graph_proto(self, checkpointable_objects,
+ node_ids,
+ slot_variables,
+ object_graph_proto=None):
+ """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
+ if object_graph_proto is None:
+ object_graph_proto = (
+ checkpointable_object_graph_pb2.CheckpointableObjectGraph())
+ for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
+ assert node_ids[checkpointable] == checkpoint_id
+ object_proto = object_graph_proto.nodes.add()
+ object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
+ for child in self.list_dependencies(checkpointable):
+ child_proto = object_proto.children.add()
+ child_proto.node_id = node_ids[child.ref]
+ child_proto.local_name = child.name
+ return object_graph_proto
+
+ def _serialize_gathered_objects(self, checkpointable_objects, path_to_root,
+ object_map=None):
+ """Create SaveableObjects and protos for gathered objects."""
+ object_names = object_identity.ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = object_identity.ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
+ slot_variables = _serialize_slot_variables(
+ checkpointable_objects=checkpointable_objects,
+ node_ids=node_ids,
+ object_names=object_names)
+ object_graph_proto = self._fill_object_graph_proto(
+ checkpointable_objects=checkpointable_objects,
+ node_ids=node_ids,
+ slot_variables=slot_variables)
+ named_saveable_objects, feed_additions = (
+ self._add_attributes_to_object_graph(
+ checkpointable_objects=checkpointable_objects,
+ object_graph_proto=object_graph_proto,
+ node_ids=node_ids,
+ object_names=object_names,
+ object_map=object_map))
+ return named_saveable_objects, object_graph_proto, feed_additions
+
+ def serialize_object_graph(self):
+ """Determine checkpoint keys for variables and build a serialized graph.
+
+ Non-slot variables are keyed based on a shortest path from the root saveable
+ to the object which owns the variable (i.e. the one which called
+ `Checkpointable._add_variable` to create it).
+
+ Slot variables are keyed based on a shortest path to the variable being
+ slotted for, a shortest path to their optimizer, and the slot name.
+
+ Returns:
+ A tuple of (named_variables, object_graph_proto, feed_additions):
+ named_variables: A dictionary mapping names to variable objects.
+ object_graph_proto: A CheckpointableObjectGraph protocol buffer
+ containing the serialized object graph and variable references.
+ feed_additions: A dictionary mapping from Tensors to values which should
+ be fed when saving.
+
+ Raises:
+ ValueError: If there are invalid characters in an optimizer's slot names.
+ """
+ checkpointable_objects, path_to_root = self._breadth_first_traversal()
+ return self._serialize_gathered_objects(
+ checkpointable_objects, path_to_root)
+
+ def frozen_saveable_objects(self, object_map=None, to_graph=None):
+ """Creates SaveableObjects with the current object graph frozen."""
+ checkpointable_objects, path_to_root = self._breadth_first_traversal()
+ if to_graph:
+ target_context = to_graph.as_default
+ else:
+ target_context = ops.NullContextmanager
+ with target_context():
+ named_saveable_objects, graph_proto, _ = self._serialize_gathered_objects(
+ checkpointable_objects,
+ path_to_root,
+ object_map)
+ with ops.device("/cpu:0"):
+ object_graph_tensor = constant_op.constant(
+ graph_proto.SerializeToString(), dtype=dtypes.string)
+ named_saveable_objects.append(
+ base.NoRestoreSaveable(
+ tensor=object_graph_tensor,
+ name=base.OBJECT_GRAPH_PROTO_KEY))
+ return named_saveable_objects
+
+ def objects_ids_and_slot_variables(self):
+ """Traverse the object graph and list all accessible objects.
+
+ Looks for `Checkpointable` objects which are dependencies of
+ `root_checkpointable`. Includes slot variables only if the variable they are
+ slotting for and the optimizer are dependencies of `root_checkpointable`
+ (i.e. if they would be saved with a checkpoint).
+
+ Returns:
+ A tuple of (checkpointable objects, object -> node id, slot variables)
+ """
+ checkpointable_objects, path_to_root = self._breadth_first_traversal()
+ object_names = object_identity.ObjectIdentityDictionary()
+ for obj, path in path_to_root.items():
+ object_names[obj] = _object_prefix_from_path(path)
+ node_ids = object_identity.ObjectIdentityDictionary()
+ for node_id, node in enumerate(checkpointable_objects):
+ node_ids[node] = node_id
+ slot_variables = _serialize_slot_variables(
+ checkpointable_objects=checkpointable_objects,
+ node_ids=node_ids,
+ object_names=object_names)
+ return checkpointable_objects, node_ids, slot_variables
+
+ def list_objects(self):
+ """Traverse the object graph and list all accessible objects."""
+ checkpointable_objects, _, _ = self.objects_ids_and_slot_variables()
+ return checkpointable_objects
diff --git a/tensorflow/python/training/checkpointable/object_identity.py b/tensorflow/python/training/checkpointable/object_identity.py
new file mode 100644
index 0000000..2d3056b
--- /dev/null
+++ b/tensorflow/python/training/checkpointable/object_identity.py
@@ -0,0 +1,156 @@
+"""Utilities for collecting objects based on "is" comparison."""
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import weakref
+
+
+class _ObjectIdentityWrapper(object):
+ """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
+
+ Since __eq__ is based on object identity, it's safe to also define __hash__
+ based on object ids. This lets us add unhashable types like checkpointable
+ _ListWrapper objects to object-identity collections.
+ """
+
+ def __init__(self, wrapped):
+ self._wrapped = wrapped
+
+ @property
+ def unwrapped(self):
+ return self._wrapped
+
+ def __eq__(self, other):
+ if isinstance(other, _ObjectIdentityWrapper):
+ return self._wrapped is other._wrapped # pylint: disable=protected-access
+ return self._wrapped is other
+
+ def __hash__(self):
+ # Wrapper id() is also fine for weakrefs. In fact, we rely on
+ # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
+ # weakref.ref(a) in _WeakObjectIdentityWrapper.
+ return id(self._wrapped)
+
+
+class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
+
+ def __init__(self, wrapped):
+ super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
+
+ @property
+ def unwrapped(self):
+ return self._wrapped()
+
+
+class ObjectIdentityDictionary(collections.MutableMapping):
+ """A mutable mapping data structure which compares using "is".
+
+ This is necessary because we have checkpointable objects (_ListWrapper) which
+ have behavior identical to built-in Python lists (including being unhashable
+ and comparing based on the equality of their contents by default).
+ """
+
+ def __init__(self):
+ self._storage = {}
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
+
+ def __getitem__(self, key):
+ return self._storage[self._wrap_key(key)]
+
+ def __setitem__(self, key, value):
+ self._storage[self._wrap_key(key)] = value
+
+ def __delitem__(self, key):
+ del self._storage[self._wrap_key(key)]
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ for key in self._storage:
+ yield key.unwrapped
+
+
+class ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
+ """Like weakref.WeakKeyDictionary, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len(list(self._storage))
+
+ def __iter__(self):
+ keys = self._storage.keys()
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ del self[key]
+ else:
+ yield unwrapped
+
+
+class ObjectIdentitySet(collections.MutableSet):
+ """Like the built-in set, but compares objects with "is"."""
+
+ def __init__(self, *args):
+ self._storage = set([self._wrap_key(obj) for obj in list(*args)])
+
+ def _wrap_key(self, key):
+ return _ObjectIdentityWrapper(key)
+
+ def __contains__(self, key):
+ return self._wrap_key(key) in self._storage
+
+ def discard(self, key):
+ self._storage.discard(self._wrap_key(key))
+
+ def add(self, key):
+ self._storage.add(self._wrap_key(key))
+
+ def __len__(self):
+ return len(self._storage)
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ yield key.unwrapped
+
+
+class ObjectIdentityWeakSet(ObjectIdentitySet):
+ """Like weakref.WeakSet, but compares objects with "is"."""
+
+ def _wrap_key(self, key):
+ return _WeakObjectIdentityWrapper(key)
+
+ def __len__(self):
+ # Iterate, discarding old weak refs
+ return len([_ for _ in self])
+
+ def __iter__(self):
+ keys = list(self._storage)
+ for key in keys:
+ unwrapped = key.unwrapped
+ if unwrapped is None:
+ self.discard(key)
+ else:
+ yield unwrapped
diff --git a/tensorflow/python/training/checkpointable/tracking.py b/tensorflow/python/training/checkpointable/tracking.py
index 04fd554..83b0d8b 100644
--- a/tensorflow/python/training/checkpointable/tracking.py
+++ b/tensorflow/python/training/checkpointable/tracking.py
@@ -18,6 +18,8 @@
from __future__ import print_function
from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
+from tensorflow.python.eager import function as defun
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops
@@ -76,10 +78,36 @@
checkpointable=self, value=value, name=name)
super(AutoCheckpointable, self).__setattr__(name, value)
+ def __delattr__(self, name):
+ self._maybe_initialize_checkpointable()
+ if name in self._unconditional_dependency_names:
+ del self._unconditional_dependency_names[name]
+ for index, (dep_name, _) in enumerate(
+ self._unconditional_checkpoint_dependencies):
+ if dep_name == name:
+ del self._unconditional_checkpoint_dependencies[index]
+ break
+ super(AutoCheckpointable, self).__delattr__(name)
+
def _no_dependency(self, value):
"""Override to allow CheckpointableBase to disable dependency tracking."""
return data_structures.NoDependency(value)
+ def _list_functions_for_serialization(self):
+ """Return a dict of `Function`s of a checkpointable."""
+ functions = dict()
+ for attribute_name in dir(self):
+ try:
+ attribute_value = getattr(self, attribute_name, None)
+ except Exception: # pylint: disable=broad-except
+ # We really don't want to throw an exception just because some object's
+ # attribute accessor is broken.
+ attribute_value = None
+ if isinstance(attribute_value, (def_function.Function,
+ defun.ConcreteFunction)):
+ functions[attribute_name] = attribute_value
+ return functions
+
class ResourceTracker(object):
"""An object that tracks a list of resources."""
diff --git a/tensorflow/python/training/checkpointable/tracking_test.py b/tensorflow/python/training/checkpointable/tracking_test.py
index eb70919..87c6603 100644
--- a/tensorflow/python/training/checkpointable/tracking_test.py
+++ b/tensorflow/python/training/checkpointable/tracking_test.py
@@ -71,6 +71,21 @@
nodeps = NoDependencyModel()
self.assertEqual([nodeps], util.list_objects(nodeps))
+ def testRemoveDependency(self):
+ root = tracking.AutoCheckpointable()
+ root.a = tracking.AutoCheckpointable()
+ self.assertEqual(1, len(root._checkpoint_dependencies))
+ self.assertEqual(1, len(root._unconditional_checkpoint_dependencies))
+ self.assertIs(root.a, root._checkpoint_dependencies[0].ref)
+ del root.a
+ self.assertFalse(hasattr(root, "a"))
+ self.assertEqual(0, len(root._checkpoint_dependencies))
+ self.assertEqual(0, len(root._unconditional_checkpoint_dependencies))
+ root.a = tracking.AutoCheckpointable()
+ self.assertEqual(1, len(root._checkpoint_dependencies))
+ self.assertEqual(1, len(root._unconditional_checkpoint_dependencies))
+ self.assertIs(root.a, root._checkpoint_dependencies[0].ref)
+
def testListBasic(self):
a = tracking.AutoCheckpointable()
b = tracking.AutoCheckpointable()
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 129ad55..1ba9720 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -18,7 +18,6 @@
from __future__ import print_function
import abc
-import collections
import os
import weakref
@@ -39,13 +38,13 @@
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import optimizer as optimizer_v1
from tensorflow.python.training import saver as v1_saver_lib
from tensorflow.python.training.checkpointable import base
from tensorflow.python.training.checkpointable import data_structures
+from tensorflow.python.training.checkpointable import graph_view as graph_view_lib
+from tensorflow.python.training.checkpointable import object_identity
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.saving import functional_saver
-from tensorflow.python.training.saving import saveable_object as saveable_object_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
@@ -53,27 +52,11 @@
from tensorflow.python.util.tf_export import tf_export
-_ESCAPE_CHAR = "." # For avoiding conflicts with user-specified names.
-
-# Keyword for identifying that the next bit of a checkpoint variable name is a
-# slot name. Checkpoint names for slot variables look like:
-#
-# <path to variable>/<_OPTIMIZER_SLOTS_NAME>/<path to optimizer>/<slot name>
-#
-# Where <path to variable> is a full path from the checkpoint root to the
-# variable being slotted for.
-_OPTIMIZER_SLOTS_NAME = _ESCAPE_CHAR + "OPTIMIZER_SLOT"
-# Keyword for separating the path to an object from the name of an
-# attribute in checkpoint names. Used like:
-# <path to variable>/<_OBJECT_ATTRIBUTES_NAME>/<name of attribute>
-_OBJECT_ATTRIBUTES_NAME = _ESCAPE_CHAR + "ATTRIBUTES"
-
-
class _CheckpointRestoreCoordinator(object):
"""Holds the status of an object-based checkpoint load."""
def __init__(self, object_graph_proto, save_path, save_path_tensor,
- restore_op_cache, saveable_object_cache):
+ restore_op_cache, graph_view):
"""Specify the checkpoint being loaded.
Args:
@@ -87,10 +70,8 @@
`_CheckpointRestoreCoordinator`s for the same Python objects, used to
look up restore ops by name to avoid re-creating them across multiple
`restore()` calls.
- saveable_object_cache: A mapping of checkpointable objects -> attribute
- names -> list(`SaveableObject`s), used when `SaveableObjects` must be
- referenced every restore (e.g. for Python state); otherwise they would
- create their own ops every restore.
+ graph_view: A graph_view_lib.ObjectGraphView object for the restored
+ objects.
"""
self.object_graph_proto = object_graph_proto
self.restore_uid = ops.uid()
@@ -108,7 +89,7 @@
# use them (for example because of inconsistent references when
# loading). Used to make status assertions fail when loading checkpoints
# that don't quite match.
- self.all_python_objects = _ObjectIdentityWeakSet()
+ self.all_python_objects = object_identity.ObjectIdentityWeakSet()
self.save_path_tensor = save_path_tensor
self.save_path_string = save_path
self.dtype_map = pywrap_tensorflow.NewCheckpointReader(
@@ -119,7 +100,7 @@
# this checkpoint.
self.restore_ops = []
self.restore_ops_by_name = restore_op_cache
- self.saveable_object_cache = saveable_object_cache
+ self.graph_view = graph_view
self.new_restore_ops_callback = None
# A mapping from optimizer proto ids to lists of slot variables to be
# restored when the optimizer is tracked. Only includes slot variables whose
@@ -345,464 +326,6 @@
return object_graph_proto
-class _ObjectIdentityWrapper(object):
- """Wraps an object, mapping __eq__ on wrapper to "is" on wrapped.
-
- Since __eq__ is based on object identity, it's safe to also define __hash__
- based on object ids. This lets us add unhashable types like checkpointable
- _ListWrapper objects to object-identity collections.
- """
-
- def __init__(self, wrapped):
- self._wrapped = wrapped
-
- @property
- def unwrapped(self):
- return self._wrapped
-
- def __eq__(self, other):
- if isinstance(other, _ObjectIdentityWrapper):
- return self._wrapped is other._wrapped # pylint: disable=protected-access
- return self._wrapped is other
-
- def __hash__(self):
- # Wrapper id() is also fine for weakrefs. In fact, we rely on
- # id(weakref.ref(a)) == id(weakref.ref(a)) and weakref.ref(a) is
- # weakref.ref(a) in _WeakObjectIdentityWrapper.
- return id(self._wrapped)
-
-
-class _WeakObjectIdentityWrapper(_ObjectIdentityWrapper):
-
- def __init__(self, wrapped):
- super(_WeakObjectIdentityWrapper, self).__init__(weakref.ref(wrapped))
-
- @property
- def unwrapped(self):
- return self._wrapped()
-
-
-class ObjectIdentityDictionary(collections.MutableMapping):
- """A mutable mapping data structure which compares using "is".
-
- This is necessary because we have checkpointable objects (_ListWrapper) which
- have behavior identical to built-in Python lists (including being unhashable
- and comparing based on the equality of their contents by default).
- """
-
- def __init__(self):
- self._storage = {}
-
- def _wrap_key(self, key):
- return _ObjectIdentityWrapper(key)
-
- def __getitem__(self, key):
- return self._storage[self._wrap_key(key)]
-
- def __setitem__(self, key, value):
- self._storage[self._wrap_key(key)] = value
-
- def __delitem__(self, key):
- del self._storage[self._wrap_key(key)]
-
- def __len__(self):
- return len(self._storage)
-
- def __iter__(self):
- for key in self._storage:
- yield key.unwrapped
-
-
-class _ObjectIdentityWeakKeyDictionary(ObjectIdentityDictionary):
- """Like weakref.WeakKeyDictionary, but compares objects with "is"."""
-
- def _wrap_key(self, key):
- return _WeakObjectIdentityWrapper(key)
-
- def __len__(self):
- # Iterate, discarding old weak refs
- return len(list(self._storage))
-
- def __iter__(self):
- keys = self._storage.keys()
- for key in keys:
- unwrapped = key.unwrapped
- if unwrapped is None:
- del self[key]
- else:
- yield unwrapped
-
-
-class _ObjectIdentitySet(collections.MutableSet):
- """Like the built-in set, but compares objects with "is"."""
-
- def __init__(self, *args):
- self._storage = set([self._wrap_key(obj) for obj in list(*args)])
-
- def _wrap_key(self, key):
- return _ObjectIdentityWrapper(key)
-
- def __contains__(self, key):
- return self._wrap_key(key) in self._storage
-
- def discard(self, key):
- self._storage.discard(self._wrap_key(key))
-
- def add(self, key):
- self._storage.add(self._wrap_key(key))
-
- def __len__(self):
- return len(self._storage)
-
- def __iter__(self):
- keys = list(self._storage)
- for key in keys:
- yield key.unwrapped
-
-
-class _ObjectIdentityWeakSet(_ObjectIdentitySet):
- """Like weakref.WeakSet, but compares objects with "is"."""
-
- def _wrap_key(self, key):
- return _WeakObjectIdentityWrapper(key)
-
- def __len__(self):
- # Iterate, discarding old weak refs
- return len([_ for _ in self])
-
- def __iter__(self):
- keys = list(self._storage)
- for key in keys:
- unwrapped = key.unwrapped
- if unwrapped is None:
- self.discard(key)
- else:
- yield unwrapped
-
-
-def _breadth_first_checkpointable_traversal(root_checkpointable):
- """Find shortest paths to all variables owned by dependencies of root."""
- bfs_sorted = []
- to_visit = collections.deque([root_checkpointable])
- path_to_root = ObjectIdentityDictionary()
- path_to_root[root_checkpointable] = ()
- while to_visit:
- current_checkpointable = to_visit.popleft()
- if isinstance(current_checkpointable, tracking.NotCheckpointable):
- raise NotImplementedError(
- ("The object %s does not support object-based saving. File a feature "
- "request if this limitation bothers you. In the meantime, you can "
- "remove the dependency on this object and save everything else.")
- % (current_checkpointable,))
- current_checkpointable._maybe_initialize_checkpointable() # pylint: disable=protected-access
- bfs_sorted.append(current_checkpointable)
- for child_checkpointable in (
- current_checkpointable._checkpoint_dependencies): # pylint: disable=protected-access
- if child_checkpointable.ref not in path_to_root:
- path_to_root[child_checkpointable.ref] = (
- path_to_root[current_checkpointable] + (child_checkpointable,))
- to_visit.append(child_checkpointable.ref)
- return bfs_sorted, path_to_root
-
-
-def _escape_local_name(name):
- # We need to support slashes in local names for compatibility, since this
- # naming scheme is being patched in to things like Layer.add_variable where
- # slashes were previously accepted. We also want to use slashes to indicate
- # edges traversed to reach the variable, so we escape forward slashes in
- # names.
- return (name.replace(_ESCAPE_CHAR, _ESCAPE_CHAR + _ESCAPE_CHAR)
- .replace(r"/", _ESCAPE_CHAR + "S"))
-
-
-def _object_prefix_from_path(path_to_root):
- return "/".join(
- (_escape_local_name(checkpointable.name)
- for checkpointable in path_to_root))
-
-
-def _slot_variable_naming_for_optimizer(optimizer_path):
- """Make a function for naming slot variables in an optimizer."""
- # Name slot variables:
- #
- # <variable name>/<_OPTIMIZER_SLOTS_NAME>/<optimizer path>/<slot name>
- #
- # where <variable name> is exactly the checkpoint name used for the original
- # variable, including the path from the checkpoint root and the local name in
- # the object which owns it. Note that we only save slot variables if the
- # variable it's slotting for is also being saved.
-
- optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, optimizer_path)
-
- def _name_slot_variable(variable_path, slot_name):
- """With an optimizer specified, name a slot variable."""
- return (variable_path
- + optimizer_identifier
- + _escape_local_name(slot_name))
-
- return _name_slot_variable
-
-
-def _serialize_slot_variables(checkpointable_objects, node_ids, object_names):
- """Gather and name slot variables."""
- non_slot_objects = list(checkpointable_objects)
- slot_variables = ObjectIdentityDictionary()
- for checkpointable in non_slot_objects:
- if (isinstance(checkpointable, optimizer_v1.Optimizer)
- # TODO(b/110718070): Fix Keras imports.
- or hasattr(checkpointable, "_create_or_restore_slot_variable")):
- naming_scheme = _slot_variable_naming_for_optimizer(
- optimizer_path=object_names[checkpointable])
- slot_names = checkpointable.get_slot_names()
- for slot_name in slot_names:
- for original_variable_node_id, original_variable in enumerate(
- non_slot_objects):
- try:
- slot_variable = checkpointable.get_slot(
- original_variable, slot_name)
- except (AttributeError, KeyError):
- slot_variable = None
- if slot_variable is None:
- continue
- slot_variable._maybe_initialize_checkpointable() # pylint: disable=protected-access
- if slot_variable._checkpoint_dependencies: # pylint: disable=protected-access
- # TODO(allenl): Gather dependencies of slot variables.
- raise NotImplementedError(
- "Currently only variables with no dependencies can be saved as "
- "slot variables. File a feature request if this limitation "
- "bothers you.")
- if slot_variable in node_ids:
- raise NotImplementedError(
- "A slot variable was re-used as a dependency of a "
- "Checkpointable object. This is not currently allowed. File a "
- "feature request if this limitation bothers you.")
- checkpoint_name = naming_scheme(
- variable_path=object_names[original_variable],
- slot_name=slot_name)
- object_names[slot_variable] = checkpoint_name
- slot_variable_node_id = len(checkpointable_objects)
- node_ids[slot_variable] = slot_variable_node_id
- checkpointable_objects.append(slot_variable)
- slot_variable_proto = (
- checkpointable_object_graph_pb2.CheckpointableObjectGraph
- .CheckpointableObject.SlotVariableReference(
- slot_name=slot_name,
- original_variable_node_id=original_variable_node_id,
- slot_variable_node_id=slot_variable_node_id))
- slot_variables.setdefault(checkpointable, []).append(
- slot_variable_proto)
- return slot_variables
-
-
-def _add_attributes_to_object_graph(
- checkpointable_objects, object_graph_proto, node_ids, object_names,
- saveables_cache, object_map):
- """Create SaveableObjects and corresponding SerializedTensor protos."""
- named_saveable_objects = []
- if saveables_cache is None:
- # No SaveableObject caching. Either we're executing eagerly, or building a
- # static save which is specialized to the current Python state.
- feed_additions = None
- else:
- # If we are caching SaveableObjects, we need to build up a feed_dict with
- # functions computing volatile Python state to be saved with the checkpoint.
- feed_additions = {}
- for checkpoint_id, (checkpointable, object_proto) in enumerate(
- zip(checkpointable_objects, object_graph_proto.nodes)):
- assert node_ids[checkpointable] == checkpoint_id
- object_name = object_names[checkpointable]
- if object_map:
- object_to_save = object_map.get(checkpointable, checkpointable)
- else:
- object_to_save = checkpointable
- if saveables_cache is not None:
- cached_attributes = saveables_cache.setdefault(object_to_save, {})
- else:
- cached_attributes = None
-
- for name, saveable_factory in (
- object_to_save._gather_saveables_for_checkpoint().items()): # pylint: disable=protected-access
- attribute = object_proto.attributes.add()
- attribute.name = name
- attribute.checkpoint_key = "%s/%s/%s" % (
- object_name, _OBJECT_ATTRIBUTES_NAME, _escape_local_name(name))
- if cached_attributes is None:
- saveables = None
- else:
- saveables = cached_attributes.get(name, None)
- if saveables is not None:
- for saveable in saveables:
- if attribute.checkpoint_key not in saveable.name:
- # The checkpoint key for this SaveableObject is different. We need
- # to re-create it.
- saveables = None
- del cached_attributes[name]
- break
- if saveables is None:
- if callable(saveable_factory):
- maybe_saveable = saveable_factory(name=attribute.checkpoint_key)
- else:
- maybe_saveable = saveable_factory
- if isinstance(maybe_saveable, saveable_object_lib.SaveableObject):
- saveables = (maybe_saveable,)
- else:
- # Figure out the name-based Saver's name for this variable. If it's
- # already a SaveableObject we'd just get the checkpoint key back, so
- # we leave full_name blank.
- saver_dict = saveable_object_util.op_list_to_dict(
- [maybe_saveable], convert_variable_to_tensor=False)
- full_name, = saver_dict.keys()
- saveables = tuple(saveable_object_util.saveable_objects_for_op(
- op=maybe_saveable, name=attribute.checkpoint_key))
- for saveable in saveables:
- saveable.full_name = full_name
- for saveable in saveables:
- if attribute.checkpoint_key not in saveable.name:
- raise AssertionError(
- ("The object %s produced a SaveableObject with name '%s' for "
- "attribute '%s'. Expected a name containing '%s'.")
- % (checkpointable, name, saveable.name,
- attribute.checkpoint_key))
- if cached_attributes is not None:
- cached_attributes[name] = saveables
-
- optional_restore = None
- for saveable in saveables:
- if optional_restore is None:
- optional_restore = saveable.optional_restore
- else:
- optional_restore = optional_restore and saveable.optional_restore
-
- if hasattr(saveable, "full_name"):
- attribute.full_name = saveable.full_name
- if isinstance(saveable, base.PythonStateSaveable):
- if feed_additions is None:
- assert saveables_cache is None
- # If we're not caching saveables, then we're either executing
- # eagerly or building a static save/restore (e.g. for a
- # SavedModel). In either case, we should embed the current Python
- # state in the graph rather than relying on a feed dict.
- saveable = saveable.freeze()
- else:
- saveable_feed_dict = saveable.feed_dict_additions()
- for new_feed_key in saveable_feed_dict.keys():
- if new_feed_key in feed_additions:
- raise AssertionError(
- ("The object %s tried to feed a value for the Tensor %s "
- "when saving, but another object is already feeding a "
- "value.")
- % (checkpointable, new_feed_key))
- feed_additions.update(saveable_feed_dict)
- named_saveable_objects.append(saveable)
- if optional_restore is None:
- optional_restore = False
- attribute.optional_restore = optional_restore
-
- return named_saveable_objects, feed_additions
-
-
-def fill_object_graph_proto(checkpointable_objects,
- node_ids,
- slot_variables,
- object_graph_proto=None):
- """Name non-slot `Checkpointable`s and add them to `object_graph_proto`."""
- if object_graph_proto is None:
- object_graph_proto = (
- checkpointable_object_graph_pb2.CheckpointableObjectGraph())
- for checkpoint_id, checkpointable in enumerate(checkpointable_objects):
- assert node_ids[checkpointable] == checkpoint_id
- object_proto = object_graph_proto.nodes.add()
- object_proto.slot_variables.extend(slot_variables.get(checkpointable, ()))
- for child in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access
- child_proto = object_proto.children.add()
- child_proto.node_id = node_ids[child.ref]
- child_proto.local_name = child.name
- return object_graph_proto
-
-
-def _serialize_gathered_objects(
- checkpointable_objects, path_to_root, saveables_cache, object_map):
- """Create SaveableObjects and protos for gathered objects."""
- object_names = ObjectIdentityDictionary()
- for obj, path in path_to_root.items():
- object_names[obj] = _object_prefix_from_path(path)
- node_ids = ObjectIdentityDictionary()
- for node_id, node in enumerate(checkpointable_objects):
- node_ids[node] = node_id
- slot_variables = _serialize_slot_variables(
- checkpointable_objects=checkpointable_objects,
- node_ids=node_ids,
- object_names=object_names)
- object_graph_proto = fill_object_graph_proto(
- checkpointable_objects=checkpointable_objects,
- node_ids=node_ids,
- slot_variables=slot_variables)
- named_saveable_objects, feed_additions = _add_attributes_to_object_graph(
- checkpointable_objects=checkpointable_objects,
- object_graph_proto=object_graph_proto,
- node_ids=node_ids,
- object_names=object_names,
- saveables_cache=saveables_cache,
- object_map=object_map)
- return named_saveable_objects, object_graph_proto, feed_additions
-
-
-def _serialize_object_graph(root_checkpointable, saveables_cache):
- """Determine checkpoint keys for variables and build a serialized graph.
-
- Non-slot variables are keyed based on a shortest path from the root saveable
- to the object which owns the variable (i.e. the one which called
- `Checkpointable._add_variable` to create it).
-
- Slot variables are keyed based on a shortest path to the variable being
- slotted for, a shortest path to their optimizer, and the slot name.
-
- Args:
- root_checkpointable: A `Checkpointable` object whose variables (including
- the variables of dependencies, recursively) should be saved.
- saveables_cache: A dictionary mapping `Checkpointable` objects -> attribute
- names -> SaveableObjects, used to avoid re-creating SaveableObjects when
- graph building.
-
- Returns:
- A tuple of (named_variables, object_graph_proto, feed_additions):
- named_variables: A dictionary mapping names to variable objects.
- object_graph_proto: A CheckpointableObjectGraph protocol buffer containing
- the serialized object graph and variable references.
- feed_additions: A dictionary mapping from Tensors to values which should
- be fed when saving.
-
- Raises:
- ValueError: If there are invalid characters in an optimizer's slot names.
- """
- checkpointable_objects, path_to_root = (
- _breadth_first_checkpointable_traversal(root_checkpointable))
- return _serialize_gathered_objects(
- checkpointable_objects, path_to_root, saveables_cache, object_map=None)
-
-
-def named_saveables(root_checkpointable):
- """Gather list of all SaveableObjects in the Checkpointable object."""
- return _serialize_object_graph(root_checkpointable, None)[0]
-
-
-def find_objects(root_checkpointable):
- """Find and number objects which are dependencies of `root_checkpointable`."""
- checkpointable_objects, path_to_root = (
- _breadth_first_checkpointable_traversal(root_checkpointable))
- object_names = ObjectIdentityDictionary()
- for obj, path in path_to_root.items():
- object_names[obj] = _object_prefix_from_path(path)
- node_ids = ObjectIdentityDictionary()
- for node_id, node in enumerate(checkpointable_objects):
- node_ids[node] = node_id
- slot_variables = _serialize_slot_variables(
- checkpointable_objects=checkpointable_objects,
- node_ids=node_ids,
- object_names=object_names)
- return checkpointable_objects, node_ids, slot_variables
-
-
def list_objects(root_checkpointable):
"""Traverse the object graph and list all accessible objects.
@@ -817,8 +340,7 @@
Returns:
A flat list of objects.
"""
- checkpointable_objects, _, _ = find_objects(root_checkpointable)
- return checkpointable_objects
+ return graph_view_lib.ObjectGraphView(root_checkpointable).list_objects()
def gather_initializers(root_checkpointable):
@@ -999,10 +521,10 @@
See `Saver.restore` for usage examples.
"""
- def __init__(self, checkpoint, feed_dict, root_checkpointable):
+ def __init__(self, checkpoint, feed_dict, graph_view):
self._checkpoint = checkpoint
self._feed_dict = feed_dict
- self._root_checkpointable = root_checkpointable
+ self._graph_view = graph_view
def assert_consumed(self):
"""Asserts that all objects in the checkpoint have been created/matched.
@@ -1055,7 +577,7 @@
and checkpointable._update_uid < self._checkpoint.restore_uid): # pylint: disable=protected-access
raise AssertionError(
"Object not assigned a value from checkpoint: %s" % (node,))
- for checkpointable_object in list_objects(self._root_checkpointable):
+ for checkpointable_object in self._graph_view.list_objects():
# Remove data structures that do not contain any variables from
# restoration checks.
if (isinstance(checkpointable_object,
@@ -1064,8 +586,9 @@
continue
self._checkpoint.all_python_objects.add(checkpointable_object)
unused_python_objects = (
- _ObjectIdentitySet(self._checkpoint.all_python_objects)
- - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values()))
+ object_identity.ObjectIdentitySet(self._checkpoint.all_python_objects)
+ - object_identity.ObjectIdentitySet(
+ self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
raise AssertionError(
("Some Python objects were not bound to checkpointed values, likely "
@@ -1075,12 +598,14 @@
def assert_nontrivial_match(self):
"""Raises an exception if only the root object matched."""
- for checkpointable_object in list_objects(self._root_checkpointable):
+ for checkpointable_object in self._graph_view.list_objects():
self._checkpoint.all_python_objects.add(checkpointable_object)
if len(self._checkpoint.object_by_proto_id) <= 1:
unused_python_objects = (
- _ObjectIdentitySet(self._checkpoint.all_python_objects)
- - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values()))
+ object_identity.ObjectIdentitySet(
+ self._checkpoint.all_python_objects)
+ - object_identity.ObjectIdentitySet(
+ self._checkpoint.object_by_proto_id.values()))
if unused_python_objects:
raise AssertionError(
("Nothing except the root object matched a checkpointed value. "
@@ -1090,7 +615,7 @@
else:
raise AssertionError(
"Nothing to load. No dependencies have been added to %s yet." % (
- self._root_checkpointable,))
+ self._graph_view.root,))
return self
def run_restore_ops(self, session=None):
@@ -1120,8 +645,8 @@
return # Initialization and restoration ops are run eagerly
if session is None:
session = ops.get_default_session()
- all_objects = list_objects(self._root_checkpointable)
- already_initialized_objects = _ObjectIdentitySet(
+ all_objects = self._graph_view.list_objects()
+ already_initialized_objects = object_identity.ObjectIdentitySet(
self._checkpoint.object_by_proto_id.values())
initializers_for_non_restored_variables = [
c.initializer for c in all_objects
@@ -1143,9 +668,9 @@
otherwise.
"""
- def __init__(self, root_checkpointable, restore_uid):
+ def __init__(self, graph_view, restore_uid):
self._restore_uid = restore_uid
- self._root_checkpointable = root_checkpointable
+ self._graph_view = graph_view
def assert_consumed(self):
"""Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
@@ -1193,7 +718,7 @@
return # run eagerly
if session is None:
session = ops.get_default_session()
- checkpointable_objects = list_objects(self._root_checkpointable)
+ checkpointable_objects = self._graph_view.list_objects()
initializers = [
c.initializer for c in checkpointable_objects
if hasattr(c, "initializer") and c.initializer is not None
@@ -1218,9 +743,9 @@
# interferes with isinstance checks.
@deprecation.deprecated(
date=None, instructions=_DEPRECATED_RESTORE_INSTRUCTIONS)
- def __init__(self, checkpoint, root_checkpointable):
+ def __init__(self, checkpoint, graph_view):
self._checkpoint = checkpoint
- self._root_checkpointable = root_checkpointable
+ self._graph_view = graph_view
def assert_consumed(self):
"""Raises an exception if any variables/objects are unmatched."""
@@ -1229,7 +754,7 @@
raise AssertionError(
"Some objects had attributes which were not restored: %s"
% (unused_attributes,))
- for checkpointable in list_objects(self._root_checkpointable):
+ for checkpointable in self._graph_view.list_objects():
# pylint: disable=protected-access
checkpointable._maybe_initialize_checkpointable()
if checkpointable._update_uid < self._checkpoint.restore_uid:
@@ -1255,7 +780,7 @@
def _gather_saveable_objects(self):
"""Walk the object graph, using global names for SaveableObjects."""
- objects = list_objects(self._root_checkpointable)
+ objects = self._graph_view.list_objects()
saveable_objects = []
for checkpointable in objects:
# pylint: disable=protected-access
@@ -1322,17 +847,13 @@
so allow additional program transformations.
"""
- def __init__(self, root_checkpointable):
+ def __init__(self, graph_view):
"""Configure saving.
Args:
- root_checkpointable: The root of the object graph to save/restore. This
- object and all of its dependencies are saved in the checkpoint. When
- restoring, objects are matched and restored starting from this root.
+ graph_view: A `GraphView` object containing a description of the object
+ graph to save.
"""
- # Allow passing in a weak reference to avoid reference cycles when
- # `Checkpointable` objects save themselves.
- self._root_checkpointable_ref = root_checkpointable
# The file prefix placeholder is created lazily when graph building (and not
# at all when executing eagerly) to avoid creating ops in the constructor
# (when they may never be necessary).
@@ -1346,34 +867,13 @@
# Op caching for restore, shared between _CheckpointRestoreCoordinators
self._restore_op_cache = {}
-
- if context.executing_eagerly():
- # SaveableObjects are always recreated when executing eagerly.
- self._saveable_object_cache = None
- else:
- # Maps Checkpointable objects -> attribute names -> list(SaveableObjects),
- # to avoid re-creating SaveableObjects when graph building.
- self._saveable_object_cache = _ObjectIdentityWeakKeyDictionary()
-
- @property
- def _root_checkpointable(self):
- if isinstance(self._root_checkpointable_ref, weakref.ref):
- derefed = self._root_checkpointable_ref()
- assert derefed is not None
- return derefed
- else:
- return self._root_checkpointable_ref
+ self._graph_view = graph_view
def _gather_saveables(
- self, object_graph_tensor=None, saveable_object_cache=None):
+ self, object_graph_tensor=None):
"""Wraps _serialize_object_graph to include the object graph proto."""
- assert ((object_graph_tensor is None and saveable_object_cache is None)
- or (object_graph_tensor is not None
- and saveable_object_cache is not None))
(named_saveable_objects, graph_proto,
- feed_additions) = _serialize_object_graph(
- self._root_checkpointable,
- saveables_cache=saveable_object_cache)
+ feed_additions) = self._graph_view.serialize_object_graph()
if object_graph_tensor is None:
with ops.device("/cpu:0"):
object_graph_tensor = constant_op.constant(
@@ -1388,52 +888,16 @@
name=base.OBJECT_GRAPH_PROTO_KEY))
return named_saveable_objects, graph_proto, feed_additions
- def gather_objects(self, object_map=None, to_graph=None):
- """Creates SaveableObjects with the current object graph frozen."""
- checkpointable_objects, path_to_root = (
- _breadth_first_checkpointable_traversal(self._root_checkpointable))
- if to_graph:
- target_context = to_graph.as_default
- else:
- target_context = ops.NullContextmanager
- with target_context():
- named_saveable_objects, graph_proto, _ = _serialize_gathered_objects(
- checkpointable_objects,
- path_to_root,
- saveables_cache=None,
- object_map=object_map)
- with ops.device("/cpu:0"):
- object_graph_tensor = constant_op.constant(
- graph_proto.SerializeToString(), dtype=dtypes.string)
- named_saveable_objects.append(
- base.NoRestoreSaveable(
- tensor=object_graph_tensor,
- name=base.OBJECT_GRAPH_PROTO_KEY))
- return named_saveable_objects
-
- def freeze(self, object_map=None, to_graph=None):
- named_saveable_objects = self.gather_objects(
- object_map=object_map, to_graph=to_graph)
- return functional_saver.Saver(named_saveable_objects)
-
def _save_cached_when_graph_building(
self,
file_prefix,
- object_graph_tensor=None,
- saveable_object_cache=None):
+ object_graph_tensor=None):
"""Create or retrieve save ops.
- When graph building, `saveable_object_cache` will typically be non-`None`,
- meaning that existing `SaveableObject`s are re-used across calls to
- `_prepare_save` even if the object graph has grown. This avoids
- unnecessarily re-creating save ops.
-
Args:
file_prefix: The prefix for saved checkpoint files.
object_graph_tensor: A `Tensor` to which the current object graph will be
fed.
- saveable_object_cache: A dictionary; if specified, used to cache
- `SaveableObject`s.
Returns:
A two-element tuple with a filename tensor and a feed_dict of tensors to
@@ -1443,8 +907,7 @@
"""
(named_saveable_objects, graph_proto,
feed_additions) = self._gather_saveables(
- object_graph_tensor=object_graph_tensor,
- saveable_object_cache=saveable_object_cache)
+ object_graph_tensor=object_graph_tensor)
if (self._last_save_object_graph != graph_proto
# When executing eagerly, we need to re-create SaveableObjects each time
# save() is called so they pick up new Tensors passed to their
@@ -1502,8 +965,7 @@
file_io.recursive_create_dir(os.path.dirname(file_prefix))
save_path, new_feed_additions = self._save_cached_when_graph_building(
file_prefix=file_prefix_tensor,
- object_graph_tensor=object_graph_tensor,
- saveable_object_cache=self._saveable_object_cache)
+ object_graph_tensor=object_graph_tensor)
if new_feed_additions:
feed_dict.update(new_feed_additions)
if not graph_building:
@@ -1576,7 +1038,7 @@
object is returned which runs restore ops from a name-based saver.
"""
if save_path is None:
- return InitializationOnlyStatus(self._root_checkpointable, ops.uid())
+ return InitializationOnlyStatus(self._graph_view, ops.uid())
reader = pywrap_tensorflow.NewCheckpointReader(save_path)
graph_building = not context.executing_eagerly()
if graph_building:
@@ -1592,7 +1054,7 @@
restore_coordinator = _NameBasedRestoreCoordinator(
save_path=save_path, dtype_map=dtype_map)
if not graph_building:
- for existing_checkpointable in list_objects(self._root_checkpointable):
+ for existing_checkpointable in self._graph_view.list_objects():
# pylint: disable=protected-access
existing_checkpointable._maybe_initialize_checkpointable()
existing_checkpointable._name_based_restores.add(restore_coordinator)
@@ -1600,7 +1062,7 @@
restore_coordinator)
# pylint: enable=protected-access
return NameBasedSaverStatus(
- restore_coordinator, root_checkpointable=self._root_checkpointable)
+ restore_coordinator, graph_view=self._graph_view)
if graph_building:
if self._file_prefix_placeholder is None:
@@ -1620,12 +1082,12 @@
save_path=save_path,
save_path_tensor=file_prefix_tensor,
restore_op_cache=self._restore_op_cache,
- saveable_object_cache=self._saveable_object_cache)
- base._CheckpointPosition( # pylint: disable=protected-access
- checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
+ graph_view=self._graph_view)
+ base.CheckpointPosition(checkpoint=checkpoint, proto_id=0).restore(
+ self._graph_view.root)
load_status = CheckpointLoadStatus(
checkpoint,
- root_checkpointable=self._root_checkpointable,
+ graph_view=self._graph_view,
feed_dict=file_prefix_feed_dict)
return load_status
@@ -1648,10 +1110,23 @@
root_checkpointable: A checkpointable object to save.
Returns:
- A `tf.train.Saver` which saves object-based checkpoints for the object graph
- frozen at the time `frozen_saver` was called.
+ A saver which saves object-based checkpoints for the object graph frozen at
+ the time `frozen_saver` was called.
"""
- return CheckpointableSaver(root_checkpointable).freeze()
+ named_saveable_objects = graph_view_lib.ObjectGraphView(
+ root_checkpointable).frozen_saveable_objects()
+ return functional_saver.Saver(named_saveable_objects)
+
+
+def saver_with_op_caching(obj):
+ """A CheckpointableSaver with a SaveableObject cache when graph building."""
+ if context.executing_eagerly():
+ saveables_cache = None
+ else:
+ saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()
+ return CheckpointableSaver(graph_view_lib.ObjectGraphView(
+ weakref.ref(obj),
+ saveables_cache=saveables_cache))
@tf_export("train.Checkpoint")
@@ -1767,7 +1242,7 @@
setattr(self, k, v)
self._save_counter = None # Created lazily for restore-on-create.
self._save_assign_op = None
- self._saver = CheckpointableSaver(weakref.ref(self))
+ self._saver = saver_with_op_caching(self)
def _maybe_create_save_counter(self):
"""Create a save counter if it does not yet exist."""
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index cef1075..4f9abcc 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -47,6 +47,7 @@
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
from tensorflow.python.training.checkpointable import base
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -125,8 +126,8 @@
# The .name attribute may be globally influenced, but the checkpoint name
# won't be (tested below).
self.assertEqual("duplicate_1:0", duplicate.name)
- named_variables, _, _ = checkpointable_utils._serialize_object_graph(
- obj, saveables_cache=None)
+ named_variables, _, _ = (
+ graph_view.ObjectGraphView(obj).serialize_object_graph())
expected_checkpoint_names = (
"a_variable/.ATTRIBUTES/VARIABLE_VALUE",
"bare_initializer/.ATTRIBUTES/VARIABLE_VALUE",
@@ -274,9 +275,8 @@
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph, _ = (
- checkpointable_utils._serialize_object_graph(
- root_checkpointable, saveables_cache=None))
+ named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
+ root_checkpointable).serialize_object_graph()
expected_slot_keys = (
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
"model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
@@ -656,8 +656,8 @@
root = tracking.AutoCheckpointable()
checkpointable_utils.add_variable(
root, name=name, shape=[1, 2], dtype=dtypes.float64)
- (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
- root, saveables_cache=None)
+ (named_variable,), _, _ = graph_view.ObjectGraphView(
+ root).serialize_object_graph()
with ops.name_scope("root/" + named_variable.name):
pass # Make sure we can use this as an op name if we prefix it.
return named_variable.name
@@ -678,8 +678,8 @@
leaf = tracking.AutoCheckpointable()
root.leaf = leaf
checkpointable_utils.add_variable(leaf, name="v", shape=[])
- (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
- root, saveables_cache=None)
+ (named_variable,), _, _ = graph_view.ObjectGraphView(
+ root).serialize_object_graph()
self.assertEqual(r"leaf/v/.ATTRIBUTES/VARIABLE_VALUE", named_variable.name)
@test_util.run_in_graph_and_eager_modes
@@ -689,8 +689,8 @@
# Dots are escaped, which avoids conflicts with reserved names.
root._track_checkpointable(leaf, name=".ATTRIBUTES")
checkpointable_utils.add_variable(checkpointable=leaf, name="a", shape=[])
- (named_variable,), _, _ = checkpointable_utils._serialize_object_graph(
- root, saveables_cache=None)
+ (named_variable,), _, _ = graph_view.ObjectGraphView(
+ root).serialize_object_graph()
self.assertEqual("..ATTRIBUTES/a/.ATTRIBUTES/VARIABLE_VALUE",
named_variable.name)
@@ -732,7 +732,7 @@
self.var = checkpointable_utils.add_variable(
self, "var", initializer=0.)
- class LateDependencies(tracking.AutoCheckpointable):
+ class LateDependencies(checkpointable_utils.Checkpoint):
def add_dep(self):
self.dep = Dependency()
@@ -743,11 +743,9 @@
self.evaluate(state_ops.assign(original.dep.var, 123.))
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- save_path = checkpointable_utils.CheckpointableSaver(
- original).save(checkpoint_prefix)
+ save_path = original.save(checkpoint_prefix)
load_into = LateDependencies()
- status = checkpointable_utils.CheckpointableSaver(
- load_into).restore(save_path)
+ status = load_into.restore(save_path)
status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
@@ -765,7 +763,7 @@
self.var = checkpointable_utils.add_variable(
self, "var", initializer=0.)
- class DepAfterVar(tracking.AutoCheckpointable):
+ class DepAfterVar(checkpointable_utils.Checkpoint):
def add_dep(self):
dep = Dependency()
@@ -777,12 +775,10 @@
self.evaluate(state_ops.assign(dep_after_var.dep.var, -14.))
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- save_path = checkpointable_utils.CheckpointableSaver(dep_after_var).save(
- checkpoint_prefix)
+ save_path = dep_after_var.save(checkpoint_prefix)
loaded_dep_after_var = DepAfterVar()
- status = checkpointable_utils.CheckpointableSaver(
- loaded_dep_after_var).restore(save_path)
+ status = loaded_dep_after_var.restore(save_path)
loaded_dep_after_var.add_dep()
status.assert_consumed()
status.run_restore_ops()
@@ -792,7 +788,7 @@
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = tracking.AutoCheckpointable()
+ root = checkpointable_utils.Checkpoint()
root.var = checkpointable_utils.add_variable(
root, name="var", initializer=0.)
optimizer = adam.Adam(0.1)
@@ -806,22 +802,18 @@
checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
self.evaluate(state_ops.assign(root.var, 12.))
- no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "no_slots"))
+ no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
root.optimizer = optimizer
self.evaluate(state_ops.assign(root.var, 13.))
self.evaluate(state_ops.assign(
optimizer.get_slot(slot_name="m", var=root.var),
14.))
- slots_path = checkpointable_utils.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "with_slots"))
- new_root = tracking.AutoCheckpointable()
+ slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
+ new_root = checkpointable_utils.Checkpoint()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
- slot_status = checkpointable_utils.CheckpointableSaver(
- new_root).restore(slots_path)
- no_slot_status = checkpointable_utils.CheckpointableSaver(
- new_root).restore(no_slots_path)
+ slot_status = new_root.restore(slots_path)
+ no_slot_status = new_root.restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
new_root.var = checkpointable_utils.add_variable(
@@ -861,22 +853,19 @@
@test_util.run_in_graph_and_eager_modes
def testOverlappingRestores(self):
checkpoint_directory = self.get_temp_dir()
- save_root = tracking.AutoCheckpointable()
+ save_root = checkpointable_utils.Checkpoint()
save_root.dep = tracking.AutoCheckpointable()
save_root.dep.var = checkpointable_utils.add_variable(
save_root.dep, name="var", initializer=0.)
self.evaluate(state_ops.assign(save_root.dep.var, 12.))
- saver = checkpointable_utils.CheckpointableSaver(save_root)
- first_path = saver.save(os.path.join(checkpoint_directory, "first"))
+ first_path = save_root.save(os.path.join(checkpoint_directory, "first"))
self.evaluate(state_ops.assign(save_root.dep.var, 13.))
- second_path = saver.save(os.path.join(checkpoint_directory, "second"))
+ second_path = save_root.save(os.path.join(checkpoint_directory, "second"))
- first_root = tracking.AutoCheckpointable()
- second_root = tracking.AutoCheckpointable()
- first_status = checkpointable_utils.CheckpointableSaver(
- first_root).restore(first_path)
- second_status = checkpointable_utils.CheckpointableSaver(
- second_root).restore(second_path)
+ first_root = checkpointable_utils.Checkpoint()
+ second_root = checkpointable_utils.Checkpoint()
+ first_status = first_root.restore(first_path)
+ second_status = second_root.restore(second_path)
load_dep = tracking.AutoCheckpointable()
load_dep.var = checkpointable_utils.add_variable(
load_dep, name="var", shape=[])
@@ -891,12 +880,10 @@
# Try again with the order of the restore() reversed. The last restore
# determines the final value.
- first_root = tracking.AutoCheckpointable()
- second_root = tracking.AutoCheckpointable()
- second_status = checkpointable_utils.CheckpointableSaver(
- second_root).restore(second_path)
- first_status = checkpointable_utils.CheckpointableSaver(
- first_root).restore(first_path)
+ first_root = checkpointable_utils.Checkpoint()
+ second_root = checkpointable_utils.Checkpoint()
+ second_status = second_root.restore(second_path)
+ first_status = first_root.restore(first_path)
load_dep = tracking.AutoCheckpointable()
load_dep.var = checkpointable_utils.add_variable(
load_dep, name="var", shape=[])
@@ -913,7 +900,7 @@
def testAmbiguousLoad(self):
# Not OK to split one checkpoint object into two
checkpoint_directory = self.get_temp_dir()
- save_root = tracking.AutoCheckpointable()
+ save_root = checkpointable_utils.Checkpoint()
save_root.dep_one = tracking.AutoCheckpointable()
save_root.dep_two = tracking.AutoCheckpointable()
dep_three = tracking.AutoCheckpointable()
@@ -921,11 +908,9 @@
save_root.dep_two.dep_three = dep_three
checkpointable_utils.add_variable(dep_three, name="var", initializer=0.)
self.evaluate(checkpointable_utils.gather_initializers(save_root))
- save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
- os.path.join(checkpoint_directory, "ckpt"))
- load_root = tracking.AutoCheckpointable()
- status = checkpointable_utils.CheckpointableSaver(load_root).restore(
- save_path)
+ save_path = save_root.save(os.path.join(checkpoint_directory, "ckpt"))
+ load_root = checkpointable_utils.Checkpoint()
+ status = load_root.restore(save_path)
load_root.dep_one = tracking.AutoCheckpointable()
load_root.dep_two = tracking.AutoCheckpointable()
load_root.dep_one.dep_three = tracking.AutoCheckpointable()
@@ -941,7 +926,7 @@
def testObjectsCombined(self):
# Currently fine to load two checkpoint objects into one Python object
checkpoint_directory = self.get_temp_dir()
- save_root = tracking.AutoCheckpointable()
+ save_root = checkpointable_utils.Checkpoint()
save_root.dep_one = tracking.AutoCheckpointable()
save_root.dep_two = tracking.AutoCheckpointable()
checkpointable_utils.add_variable(
@@ -949,16 +934,15 @@
checkpointable_utils.add_variable(
save_root.dep_two, name="var2", initializer=64., dtype=dtypes.float64)
self.evaluate(checkpointable_utils.gather_initializers(save_root))
- save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
- os.path.join(checkpoint_directory, "ckpt"))
- load_root = tracking.AutoCheckpointable()
+ save_path = save_root.save(os.path.join(checkpoint_directory, "ckpt"))
+ load_root = checkpointable_utils.Checkpoint()
load_root.dep_one = tracking.AutoCheckpointable()
load_root.dep_two = load_root.dep_one
v1 = checkpointable_utils.add_variable(
load_root.dep_one, name="var1", shape=[], dtype=dtypes.float64)
v2 = checkpointable_utils.add_variable(
load_root.dep_one, name="var2", shape=[], dtype=dtypes.float64)
- status = checkpointable_utils.CheckpointableSaver(load_root).restore(
+ status = load_root.restore(
save_path).assert_consumed().assert_existing_objects_matched()
status.run_restore_ops()
self.assertEqual(32., self.evaluate(v1))
@@ -968,8 +952,8 @@
def testDependencyLoop(self):
# Note: this test creates garbage during eager execution because it
# purposefully creates a reference cycle.
- first = tracking.AutoCheckpointable()
- second = tracking.AutoCheckpointable()
+ first = checkpointable_utils.Checkpoint()
+ second = checkpointable_utils.Checkpoint()
first.second = second
second.first = first
first.v = checkpointable_utils.add_variable(
@@ -978,13 +962,11 @@
second, "v2", initializer=[1., 1., 2., 3.])
self.evaluate(checkpointable_utils.gather_initializers(first))
checkpoint_directory = self.get_temp_dir()
- save_path = checkpointable_utils.CheckpointableSaver(first).save(
- os.path.join(checkpoint_directory, "ckpt"))
+ save_path = first.save(os.path.join(checkpoint_directory, "ckpt"))
# Test deferred loading
- first_load = tracking.AutoCheckpointable()
- status = checkpointable_utils.CheckpointableSaver(
- first_load).restore(save_path)
+ first_load = checkpointable_utils.Checkpoint()
+ status = first_load.restore(save_path)
second_load = tracking.AutoCheckpointable()
first_load.second = second_load
second_load.first = first_load
@@ -1004,8 +986,7 @@
self.assertAllEqual([2., 7., 1.], self.evaluate(first_load.v))
self.evaluate(second_load.v.assign([2., 7., 1., 8.]))
self.assertAllEqual([2., 7., 1., 8.], self.evaluate(second_load.v))
- status = checkpointable_utils.CheckpointableSaver(first_load).restore(
- save_path).assert_consumed()
+ status = first_load.restore(save_path).assert_consumed()
status.run_restore_ops()
self.assertAllEqual([3., 1., 4.], self.evaluate(first_load.v))
self.assertAllEqual([1., 1., 2., 3.], self.evaluate(second_load.v))
@@ -1014,18 +995,16 @@
def testRestoreOnAssign(self):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- first = tracking.AutoCheckpointable()
+ first = checkpointable_utils.Checkpoint()
first.var1 = variables_lib.Variable(0., name="outside_var")
first.var2 = variables_lib.Variable(0., name="blah")
self.evaluate(first.var1.assign(4.))
self.evaluate(first.var2.assign(8.))
- save_path = checkpointable_utils.CheckpointableSaver(first).save(
- checkpoint_prefix)
+ save_path = first.save(checkpoint_prefix)
- second = tracking.AutoCheckpointable()
+ second = checkpointable_utils.Checkpoint()
second.var2 = variables_lib.Variable(0., name="blah")
- status = checkpointable_utils.CheckpointableSaver(
- second).restore(save_path)
+ status = second.restore(save_path)
recreated_var1 = variables_lib.Variable(0., name="outside_var")
status.run_restore_ops()
self.assertEqual(8., self.evaluate(second.var2))
@@ -1042,17 +1021,16 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = checkpointable_utils.Checkpoint()
obj.var = variables_lib.Variable(0., name="v")
obj.opt = adam.Adam(0.1)
variables = [obj.var]
gradients = [1.]
obj.opt.apply_gradients(zip(gradients, variables))
self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
graph.finalize()
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
@test_util.run_in_graph_and_eager_modes
def testCheckpointState(self):
@@ -1132,18 +1110,17 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = checkpointable_utils.Checkpoint()
obj.var = variables_lib.Variable(0., name="v")
obj.opt = adam.Adam(0.1)
variables = [obj.var]
gradients = [1.]
obj.opt.apply_gradients(zip(gradients, variables))
self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
- save_path = saver.save(checkpoint_prefix)
- saver.restore(save_path)
+ save_path = obj.save(checkpoint_prefix)
+ obj.restore(save_path)
graph.finalize()
- saver.restore(save_path)
+ obj.restore(save_path)
@test_util.run_in_graph_and_eager_modes
def test_sequential(self):
@@ -1461,7 +1438,8 @@
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = checkpointable_utils.CheckpointableSaver(
+ graph_view.ObjectGraphView(root))
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():
diff --git a/tensorflow/python/training/checkpointable/util_with_v1_optimizers_test.py b/tensorflow/python/training/checkpointable/util_with_v1_optimizers_test.py
index bd80fa6..d7158c0 100644
--- a/tensorflow/python/training/checkpointable/util_with_v1_optimizers_test.py
+++ b/tensorflow/python/training/checkpointable/util_with_v1_optimizers_test.py
@@ -43,6 +43,7 @@
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import tracking
from tensorflow.python.training.checkpointable import util as checkpointable_utils
@@ -100,9 +101,8 @@
self.evaluate(checkpointable_utils.gather_initializers(
root_checkpointable))
self.evaluate(train_op)
- named_variables, serialized_graph, _ = (
- checkpointable_utils._serialize_object_graph(
- root_checkpointable, saveables_cache=None))
+ named_variables, serialized_graph, _ = graph_view.ObjectGraphView(
+ root_checkpointable).serialize_object_graph()
expected_checkpoint_names = (
# Created in the root node, so no prefix.
"optimizer_step",
@@ -503,7 +503,7 @@
def testDeferredSlotRestoration(self):
checkpoint_directory = self.get_temp_dir()
- root = tracking.AutoCheckpointable()
+ root = checkpointable_utils.Checkpoint()
root.var = checkpointable_utils.add_variable(
root, name="var", initializer=0.)
optimizer = adam.AdamOptimizer(0.1)
@@ -518,21 +518,17 @@
checkpointable_utils.Checkpoint(root=root, optimizer=optimizer)))
self.evaluate(train_op)
self.evaluate(state_ops.assign(root.var, 12.))
- no_slots_path = checkpointable_utils.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "no_slots"))
+ no_slots_path = root.save(os.path.join(checkpoint_directory, "no_slots"))
root.optimizer = optimizer
self.evaluate(state_ops.assign(root.var, 13.))
self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
14.))
- slots_path = checkpointable_utils.CheckpointableSaver(root).save(
- os.path.join(checkpoint_directory, "with_slots"))
- new_root = tracking.AutoCheckpointable()
+ slots_path = root.save(os.path.join(checkpoint_directory, "with_slots"))
+ new_root = checkpointable_utils.Checkpoint()
# Load the slot-containing checkpoint (deferred), then immediately overwrite
# the non-slot variable (also deferred).
- slot_status = checkpointable_utils.CheckpointableSaver(
- new_root).restore(slots_path)
- no_slot_status = checkpointable_utils.CheckpointableSaver(
- new_root).restore(no_slots_path)
+ slot_status = new_root.restore(slots_path)
+ no_slot_status = new_root.restore(no_slots_path)
with self.assertRaises(AssertionError):
no_slot_status.assert_consumed()
new_root.var = checkpointable_utils.add_variable(
@@ -572,15 +568,14 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = checkpointable_utils.Checkpoint()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
before_ops = graph.get_operations()
- saver.save(checkpoint_prefix)
+ obj.save(checkpoint_prefix)
self.assertEqual(before_ops, graph.get_operations())
def testManyRestoresGraph(self):
@@ -590,16 +585,15 @@
with graph.as_default(), self.session(graph):
checkpoint_directory = self.get_temp_dir()
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
- obj = tracking.AutoCheckpointable()
+ obj = checkpointable_utils.Checkpoint()
obj.var = variable_scope.get_variable(name="v", initializer=0.)
obj.opt = adam.AdamOptimizer(0.1)
obj.opt.minimize(obj.var.read_value())
self.evaluate(checkpointable_utils.gather_initializers(obj))
- saver = checkpointable_utils.CheckpointableSaver(obj)
- save_path = saver.save(checkpoint_prefix)
- saver.restore(save_path)
+ save_path = obj.save(checkpoint_prefix)
+ obj.restore(save_path)
before_ops = graph.get_operations()
- saver.restore(save_path)
+ obj.restore(save_path)
self.assertEqual(before_ops, graph.get_operations())
def testMultipleGraphsNonSlotVariables(self):
@@ -869,7 +863,8 @@
self._set_sentinels(root)
with self.assertRaises(AssertionError):
self._check_sentinels(root)
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = checkpointable_utils.CheckpointableSaver(
+ graph_view.ObjectGraphView(root))
self._set_sentinels(root)
status = object_saver.restore(save_path)
if context.executing_eagerly():
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 41a42bd..1687898 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -41,6 +41,7 @@
from tensorflow.python.training import saver as training_saver
from tensorflow.python.training import session_manager as sm
from tensorflow.python.training import session_run_hook
+from tensorflow.python.training.checkpointable import graph_view
from tensorflow.python.training.checkpointable import util as checkpointable_util
from tensorflow.python.util import function_utils
from tensorflow.python.util.tf_export import tf_export
@@ -229,8 +230,8 @@
# pylint: enable=g-long-lambda
if isinstance(self._saver, checkpointable_util.Checkpoint):
self._saver = training_saver.Saver(
- var_list=checkpointable_util.CheckpointableSaver(
- self._saver).gather_objects(),
+ var_list=graph_view.ObjectGraphView(
+ self._saver).frozen_saveable_objects(),
sharded=True)
else:
self._saver.build()
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 3742ebb..a98fcc2 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -24,7 +24,6 @@
import six
-from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as distribute_ctx
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.eager import backprop
@@ -461,12 +460,6 @@
tape.watch(var_list)
loss_value = loss()
- # Scale loss if using a "mean" loss reduction and multiple replicas.
- # Have to be careful to call distribute_lib.get_loss_reduction()
- # *after* loss() is evaluated, so we know what loss reduction it uses.
- # TODO(josh11b): Test that we handle weight decay in a reasonable way.
- loss_value = self._scale_loss(loss_value)
-
if var_list is None:
var_list = tape.watched_variables()
# TODO(jhseu): Figure out why GradientTape's gradients don't require loss
@@ -481,9 +474,6 @@
"`loss` passed to Optimizer.compute_gradients should "
"be a function when eager execution is enabled.")
- # Scale loss if using a "mean" loss reduction and multiple replicas.
- loss = self._scale_loss(loss)
-
if gate_gradients not in [Optimizer.GATE_NONE, Optimizer.GATE_OP,
Optimizer.GATE_GRAPH]:
raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
@@ -518,14 +508,6 @@
if g is not None and v.dtype != dtypes.resource])
return grads_and_vars
- @staticmethod
- def _scale_loss(loss_value):
- if distribute_lib.get_loss_reduction() == ds_reduce_util.ReduceOp.MEAN:
- num_replicas = distribute_ctx.get_strategy().num_replicas_in_sync
- if num_replicas > 1:
- loss_value *= (1. / num_replicas)
- return loss_value
-
def apply_gradients(self, grads_and_vars, global_step=None, name=None):
"""Apply gradients to variables.
diff --git a/tensorflow/python/training/optimizer_test.py b/tensorflow/python/training/optimizer_test.py
index e175b5a..ac831cb 100644
--- a/tensorflow/python/training/optimizer_test.py
+++ b/tensorflow/python/training/optimizer_test.py
@@ -24,7 +24,7 @@
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
-from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
@@ -75,7 +75,7 @@
opt_op = sgd_op.minimize(
cost,
global_step, [var0, var1],
- aggregation_method=gradients_impl.AggregationMethod.
+ aggregation_method=gradients_util.AggregationMethod.
EXPERIMENTAL_ACCUMULATE_N)
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index dec23c5..dfc43ee 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -3012,13 +3012,13 @@
save_graph = ops_lib.Graph()
with save_graph.as_default(), self.session(graph=save_graph) as sess:
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = checkpointable_utils.Checkpoint(root=root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
# An incompatible object-based checkpoint to check error messages
var = resource_variable_ops.ResourceVariable(1., name="a")
self.evaluate(var.initializer)
- second_saver = checkpointable_utils.CheckpointableSaver(var)
+ second_saver = checkpointable_utils.Checkpoint(v=var)
second_path = second_saver.save(file_prefix=os.path.join(
checkpoint_directory, "second"))
@@ -3046,7 +3046,7 @@
save_graph = ops_lib.Graph()
with save_graph.as_default(), self.session(graph=save_graph):
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
+ object_saver = checkpointable_utils.Checkpoint(root=root)
save_path = object_saver.save(file_prefix=checkpoint_prefix)
with context.eager_mode():
diff --git a/tensorflow/python/util/example_parser_configuration.py b/tensorflow/python/util/example_parser_configuration.py
index e3fdcf9..dc8937a 100644
--- a/tensorflow/python/util/example_parser_configuration.py
+++ b/tensorflow/python/util/example_parser_configuration.py
@@ -101,7 +101,7 @@
fixed_config.shape.CopyFrom(
tensor_shape.TensorShape(dense_shapes[i]).as_proto())
- fixed_config.dtype = int(dense_types[i])
+ fixed_config.dtype = dense_types[i].as_datatype_enum
# Get the output tensor name.
fixed_config.values_output_tensor_name = parse_example_op.outputs[
dense_values_start + i].name
@@ -111,7 +111,7 @@
key = fetched[sparse_keys_start + i]
feature_config = config.feature_map[key]
var_len_feature = feature_config.var_len_feature
- var_len_feature.dtype = int(sparse_types[i])
+ var_len_feature.dtype = sparse_types[i].as_datatype_enum
var_len_feature.indices_output_tensor_name = parse_example_op.outputs[
sparse_indices_start + i].name
var_len_feature.values_output_tensor_name = parse_example_op.outputs[
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index a43ec48..2435694 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -761,6 +761,100 @@
return list(v for _, v in _yield_flat_up_to(shallow_tree, input_tree))
+def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
+ """Flattens `input_tree` up to `shallow_tree`.
+
+ Any further depth in structure in `input_tree` is retained as elements in the
+ partially flattened output.
+
+ Returns a list of (path, value) pairs, where value a leaf node in the
+ flattened tree, and path is the tuple path of that leaf in input_tree.
+
+ If `shallow_tree` and `input_tree` are not sequences, this returns a
+ single-element list: `[((), input_tree)]`.
+
+ Use Case:
+
+ Sometimes we may wish to partially flatten a nested sequence, retaining some
+ of the nested structure. We achieve this by specifying a shallow structure,
+ `shallow_tree`, we wish to flatten up to.
+
+ The input, `input_tree`, can be thought of as having the same structure as
+ `shallow_tree`, but with leaf nodes that are themselves tree structures.
+
+ Examples:
+
+ ```python
+ input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
+ shallow_tree = [[True, True], [False, True]]
+
+ flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
+ input_tree)
+ flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
+ shallow_tree)
+
+ # Output is:
+ # [((0, 0), [2, 2]),
+ # ((0, 1), [3, 3]),
+ # ((1, 0), [4, 9]),
+ # ((1, 1), [5, 5])]
+ #
+ # [((0, 0), True),
+ # ((0, 1), True),
+ # ((1, 0), False),
+ # ((1, 1), True)]
+ ```
+
+ ```python
+ input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
+ shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]
+
+ input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
+ input_tree_flattened = flatten(input_tree)
+
+ # Output is:
+ # [((0, 0), ('a', 1)),
+ # ((0, 1, 0), ('b', 2)),
+ # ((0, 1, 1, 0), ('c', 3)),
+ # ((0, 1, 1, 1), ('d', 4))]
+ # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
+ ```
+
+ Non-Sequence Edge Cases:
+
+ ```python
+ flatten_with_tuple_paths_up_to(0, 0) # Output: [(), 0]
+
+ flatten_with_tuple_paths_up_to(0, [0, 1, 2]) # Output: [(), [0, 1, 2]]
+
+ flatten_with_tuple_paths_up_to([0, 1, 2], 0) # Output: TypeError
+
+ flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
+ # Output: [((0,) 0), ((1,), 1), ((2,), 2)]
+ ```
+
+ Args:
+ shallow_tree: a possibly pruned structure of input_tree.
+ input_tree: an arbitrarily nested structure or a scalar object.
+ Note, numpy arrays are considered scalars.
+ check_types: bool. If True, check that each node in shallow_tree has the
+ same type as the corresponding node in input_tree.
+
+ Returns:
+ A Python list, the partially flattened version of `input_tree` according to
+ the structure of `shallow_tree`.
+
+ Raises:
+ TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
+ TypeError: If the sequence types of `shallow_tree` are different from
+ `input_tree`.
+ ValueError: If the sequence lengths of `shallow_tree` are different from
+ `input_tree`.
+ """
+ assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
+ return list(_yield_flat_up_to(shallow_tree, input_tree))
+
+
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
"""Applies a function or op to a number of partially flattened inputs.
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 71034ff..ec559bd 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -686,6 +686,244 @@
flattened_shallow_tree = nest.flatten_up_to(shallow_tree, shallow_tree)
self.assertEqual(flattened_shallow_tree, shallow_tree)
+ def testFlattenWithTuplePathsUpTo(self):
+ def get_paths_and_values(shallow_tree, input_tree):
+ path_value_pairs = nest.flatten_with_tuple_paths_up_to(shallow_tree,
+ input_tree)
+ paths = [p for p, _ in path_value_pairs]
+ values = [v for _, v in path_value_pairs]
+ return paths, values
+
+ # Shallow tree ends at scalar.
+ input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
+ shallow_tree = [[True, True], [False, True]]
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths,
+ [(0, 0), (0, 1), (1, 0), (1, 1)])
+ self.assertEqual(flattened_input_tree, [[2, 2], [3, 3], [4, 9], [5, 5]])
+ self.assertEqual(flattened_shallow_tree_paths,
+ [(0, 0), (0, 1), (1, 0), (1, 1)])
+ self.assertEqual(flattened_shallow_tree, [True, True, False, True])
+
+ # Shallow tree ends at string.
+ input_tree = [[("a", 1), [("b", 2), [("c", 3), [("d", 4)]]]]]
+ shallow_tree = [["level_1", ["level_2", ["level_3", ["level_4"]]]]]
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ input_tree_flattened_paths = [p for p, _ in
+ nest.flatten_with_tuple_paths(input_tree)]
+ input_tree_flattened = nest.flatten(input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [(0, 0), (0, 1, 0), (0, 1, 1, 0), (0, 1, 1, 1, 0)])
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [("a", 1), ("b", 2), ("c", 3), ("d", 4)])
+
+ self.assertEqual(input_tree_flattened_paths,
+ [(0, 0, 0), (0, 0, 1),
+ (0, 1, 0, 0), (0, 1, 0, 1),
+ (0, 1, 1, 0, 0), (0, 1, 1, 0, 1),
+ (0, 1, 1, 1, 0, 0), (0, 1, 1, 1, 0, 1)])
+ self.assertEqual(input_tree_flattened, ["a", 1, "b", 2, "c", 3, "d", 4])
+
+ # Make sure dicts are correctly flattened, yielding values, not keys.
+ input_tree = {"a": 1, "b": {"c": 2}, "d": [3, (4, 5)]}
+ shallow_tree = {"a": 0, "b": 0, "d": [0, 0]}
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [("a",), ("b",), ("d", 0), ("d", 1)])
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [1, {"c": 2}, 3, (4, 5)])
+
+ # Namedtuples.
+ ab_tuple = collections.namedtuple("ab_tuple", "a, b")
+ input_tree = ab_tuple(a=[0, 1], b=2)
+ shallow_tree = ab_tuple(a=0, b=1)
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [("a",), ("b",)])
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [[0, 1], 2])
+
+ # Nested dicts, OrderedDicts and namedtuples.
+ input_tree = collections.OrderedDict(
+ [("a", ab_tuple(a=[0, {"b": 1}], b=2)),
+ ("c", {"d": 3, "e": collections.OrderedDict([("f", 4)])})])
+ shallow_tree = input_tree
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [("a", "a", 0),
+ ("a", "a", 1, "b"),
+ ("a", "b"),
+ ("c", "d"),
+ ("c", "e", "f")])
+ self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
+ shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [("a",),
+ ("c", "d"),
+ ("c", "e")])
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [ab_tuple(a=[0, {"b": 1}], b=2),
+ 3,
+ collections.OrderedDict([("f", 4)])])
+ shallow_tree = collections.OrderedDict([("a", 0), ("c", 0)])
+ (input_tree_flattened_as_shallow_tree_paths,
+ input_tree_flattened_as_shallow_tree) = get_paths_and_values(shallow_tree,
+ input_tree)
+ self.assertEqual(input_tree_flattened_as_shallow_tree_paths,
+ [("a",), ("c",)])
+ self.assertEqual(input_tree_flattened_as_shallow_tree,
+ [ab_tuple(a=[0, {"b": 1}], b=2),
+ {"d": 3, "e": collections.OrderedDict([("f", 4)])}])
+
+ ## Shallow non-list edge-case.
+ # Using iterable elements.
+ input_tree = ["input_tree"]
+ shallow_tree = "shallow_tree"
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ input_tree = ["input_tree_0", "input_tree_1"]
+ shallow_tree = "shallow_tree"
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ # Test case where len(shallow_tree) < len(input_tree)
+ input_tree = {"a": "A", "b": "B", "c": "C"}
+ shallow_tree = {"a": 1, "c": 2}
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [("a",), ("c",)])
+ self.assertEqual(flattened_input_tree, ["A", "C"])
+ self.assertEqual(flattened_shallow_tree_paths, [("a",), ("c",)])
+ self.assertEqual(flattened_shallow_tree, [1, 2])
+
+ # Using non-iterable elements.
+ input_tree = [0]
+ shallow_tree = 9
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ input_tree = [0, 1]
+ shallow_tree = 9
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ ## Both non-list edge-case.
+ # Using iterable elements.
+ input_tree = "input_tree"
+ shallow_tree = "shallow_tree"
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ # Using non-iterable elements.
+ input_tree = 0
+ shallow_tree = 0
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_input_tree_paths, [()])
+ self.assertEqual(flattened_input_tree, [input_tree])
+ self.assertEqual(flattened_shallow_tree_paths, [()])
+ self.assertEqual(flattened_shallow_tree, [shallow_tree])
+
+ ## Input non-list edge-case.
+ # Using iterable elements.
+ input_tree = "input_tree"
+ shallow_tree = ["shallow_tree"]
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree_paths, [(0,)])
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ input_tree = "input_tree"
+ shallow_tree = ["shallow_tree_9", "shallow_tree_8"]
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ # Using non-iterable elements.
+ input_tree = 0
+ shallow_tree = [9]
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree_paths, [(0,)])
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
+ input_tree = 0
+ shallow_tree = [9, 8]
+ with self.assertRaisesWithLiteralMatch(
+ TypeError,
+ nest._IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree))):
+ (flattened_input_tree_paths,
+ flattened_input_tree) = get_paths_and_values(shallow_tree, input_tree)
+ (flattened_shallow_tree_paths,
+ flattened_shallow_tree) = get_paths_and_values(shallow_tree, shallow_tree)
+ self.assertEqual(flattened_shallow_tree_paths, [(0,), (1,)])
+ self.assertEqual(flattened_shallow_tree, shallow_tree)
+
def testMapStructureUpTo(self):
# Named tuples.
ab_tuple = collections.namedtuple("ab_tuple", "a, b")
diff --git a/tensorflow/stream_executor/cuda/BUILD b/tensorflow/stream_executor/cuda/BUILD
index 9dc3be4..03dd92e 100644
--- a/tensorflow/stream_executor/cuda/BUILD
+++ b/tensorflow/stream_executor/cuda/BUILD
@@ -45,17 +45,19 @@
srcs = if_cuda_is_configured(["cuda_platform.cc"]),
hdrs = if_cuda_is_configured(["cuda_platform.h"]),
visibility = ["//visibility:public"],
- deps = if_cuda_is_configured([
- ":cuda_driver",
- ":cuda_gpu_executor",
- ":cuda_platform_id",
- "//tensorflow/stream_executor", # buildcleaner: keep
- "//tensorflow/stream_executor:executor_cache",
- "//tensorflow/stream_executor:multi_platform_manager",
- "//tensorflow/stream_executor:stream_executor_pimpl_header",
- "//tensorflow/stream_executor/lib",
- "//tensorflow/stream_executor/platform",
- ] + tf_additional_cuda_platform_deps()),
+ deps = if_cuda_is_configured(
+ [
+ ":cuda_driver",
+ ":cuda_gpu_executor",
+ ":cuda_platform_id",
+ "//tensorflow/stream_executor", # buildcleaner: keep
+ "//tensorflow/stream_executor:executor_cache",
+ "//tensorflow/stream_executor:multi_platform_manager",
+ "//tensorflow/stream_executor:stream_executor_pimpl_header",
+ "//tensorflow/stream_executor/lib",
+ "//tensorflow/stream_executor/platform",
+ ],
+ ) + tf_additional_cuda_platform_deps(),
alwayslink = True, # Registers itself with the MultiPlatformManager.
)
@@ -90,7 +92,36 @@
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
- ] + tf_additional_cuda_driver_deps()),
+ ] + tf_additional_cuda_driver_deps()) + select({
+ # include dynamic loading implementation only when if_cuda_is_configured and build dynamically
+ "//tensorflow:using_cuda_nvcc_with_dynamic_build": ["cudart_stub"],
+ "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub"],
+ "//conditions:default": ["//tensorflow/core:cuda"],
+ }),
+)
+
+cc_library(
+ name = "cudart_stub",
+ srcs = select({
+ # include dynamic loading implementation only when if_cuda_is_configured and build dynamically
+ "//tensorflow:using_cuda_nvcc_with_dynamic_build": ["cudart_stub.cc"],
+ "//tensorflow:using_cuda_clang_with_dynamic_build": ["cudart_stub.cc"],
+ "//conditions:default": [],
+ }),
+ visibility = ["//visibility:public"],
+ deps = select({
+ "//tensorflow:using_cuda_nvcc_with_dynamic_build": [
+ "@local_config_cuda//cuda:cuda_headers",
+ "//tensorflow/stream_executor/lib",
+ "//tensorflow/stream_executor/platform:dso_loader",
+ ],
+ "//tensorflow:using_cuda_clang_with_dynamic_build": [
+ "@local_config_cuda//cuda:cuda_headers",
+ "//tensorflow/stream_executor/lib",
+ "//tensorflow/stream_executor/platform:dso_loader",
+ ],
+ "//conditions:default": [],
+ }),
)
# The activation library is tightly coupled to the executor library.
@@ -221,7 +252,9 @@
"//tensorflow/stream_executor/lib",
"//tensorflow/stream_executor/platform",
"//tensorflow/stream_executor/platform:dso_loader",
- ] + tf_additional_cudnn_plugin_deps() + if_static(["@local_config_cuda//cuda:cudnn"])),
+ ]) + tf_additional_cudnn_plugin_deps() + if_cuda_is_configured(if_static([
+ "@local_config_cuda//cuda:cudnn",
+ ])),
alwayslink = True,
)
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 06739e8..bae71b4 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -784,7 +784,6 @@
}
// A helper function to decide whether to enable deterministic functionality.
-// TODO(pr/24355): Support all cuDNN functionality (currently only convolution).
bool RequireDeterminism() {
static bool is_enabled = [] {
bool is_enabled = false;
@@ -887,10 +886,13 @@
std::transform(shape64.cbegin(), shape64.cend(), shape.begin(),
&CheckedNarrowing<int64, int>);
bool propagate_nans = pooling_descriptor.propagate_nans();
+ auto cudnn_max_pooling_mode = RequireDeterminism()
+ ? CUDNN_POOLING_MAX_DETERMINISTIC
+ : CUDNN_POOLING_MAX;
CHECK_CUDNN_OK(cudnnSetPoolingNdDescriptor(
handle_.get(),
(pooling_descriptor.mode() == dnn::PoolingMode::kMaximum
- ? CUDNN_POOLING_MAX
+ ? cudnn_max_pooling_mode
: CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING),
propagate_nans ? CUDNN_PROPAGATE_NAN : CUDNN_NOT_PROPAGATE_NAN, nd,
shape.data(), padding.data(), strides.data()));
@@ -4130,13 +4132,6 @@
return IsStatusOk(status, /*report_error=*/true);
}
-bool CudnnSupport::DoNormalize(
- Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
- const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
- LOG(FATAL) << "not yet implemented"; // TODO(leary)
- return false;
-}
-
bool CudnnSupport::DoNormalizeWithDimensions(
Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
const dnn::BatchDescriptor& dimensions,
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 80c7c8a..d8a8ddf 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -484,11 +484,6 @@
DeviceMemory<Eigen::half>* output_diff_data,
ScratchAllocator* workspace_allocator) override;
- bool DoNormalize(Stream* stream,
- const dnn::NormalizeDescriptor& normalize_descriptor,
- const DeviceMemory<float>& input_data,
- DeviceMemory<float>* output_data) override;
-
bool DoNormalizeWithDimensions(
Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
const dnn::BatchDescriptor& dimensions,
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 34ba7c5..080c26f 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -24,6 +24,7 @@
#include "absl/base/casts.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
+#include "cuda/include/cuda_runtime_api.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver_wrapper.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -107,11 +108,11 @@
// Formats CUresult to output prettified values into a log stream.
string ToString(CUresult result) {
- const char *error_name;
+ const char* error_name;
if (tensorflow::wrap::cuGetErrorName(result, &error_name)) {
return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
}
- const char *error_string;
+ const char* error_string;
if (tensorflow::wrap::cuGetErrorString(result, &error_string)) {
return error_name;
}
@@ -139,14 +140,14 @@
// thread::ThreadPool on some platforms), we run certain routines in this pool
// and wait for completion.
static mutex driver_executor_threadpool_mu(LINKER_INITIALIZED);
-static port::ThreadPool *InitializeDriverExecutor() {
+static port::ThreadPool* InitializeDriverExecutor() {
return new port::ThreadPool(port::Env::Default(), port::ThreadOptions(),
"cuda_driver", 1);
}
-port::ThreadPool *GetDriverExecutor() {
+port::ThreadPool* GetDriverExecutor() {
mutex_lock lock(driver_executor_threadpool_mu);
- static port::ThreadPool *thread_pool = InitializeDriverExecutor();
+ static port::ThreadPool* thread_pool = InitializeDriverExecutor();
return thread_pool;
}
@@ -165,12 +166,30 @@
namespace {
+template <typename PtrT>
+bool PointerIsValid(const PtrT ptr) {
+ // Checks that the pointer is to a location on the device it purports to be.
+ // PtrT is one of CUdeviceptr or void*. If it's a CUdeviceptr, then
+ // cudaPointerGetAttributes should not fail, and return a memoryType of
+ // cudaMemoryTypeDevice.
+
+ bool is_host_ptr = !std::is_same<PtrT, CUdeviceptr>::value;
+ cudaPointerAttributes attributes;
+ cudaError_t err =
+ cudaPointerGetAttributes(&attributes, reinterpret_cast<const void*>(ptr));
+ // If we failed, reset cuda error status to avoid poisoning cuda streams.
+ if (err != cudaSuccess) cudaGetLastError();
+ bool points_to_host_memory = (err == cudaErrorInvalidValue ||
+ attributes.memoryType != cudaMemoryTypeDevice);
+ return (is_host_ptr == points_to_host_memory);
+}
+
// Call cuCtxtSynchronize and crash if it doesn't succeed.
void SynchronizeOrDie() {
auto res = tensorflow::wrap::cuCtxSynchronize();
if (res != CUDA_SUCCESS) {
- LOG(FATAL) << "Synchronize found "
- << ToString(res) << " :: " << port::CurrentStackTrace();
+ LOG(FATAL) << "Synchronize found " << ToString(res)
+ << " :: " << port::CurrentStackTrace();
}
}
@@ -284,7 +303,6 @@
: "false";
}
-
// Actually performs the work of CUDA initialization. Wrapped up in one-time
// execution guard.
static port::Status InternalInit() {
@@ -312,7 +330,7 @@
// called once, but GpuDriver::Init may be called many times.
static port::Status init_retval;
static bool set = false;
- static mutex *init_mu = new mutex;
+ static mutex* init_mu = new mutex;
mutex_lock lock(*init_mu);
if (!set) {
@@ -351,8 +369,8 @@
return true;
}
-bool DeviceOptionsToContextFlags(const DeviceOptions &device_options,
- int *flags) {
+bool DeviceOptionsToContextFlags(const DeviceOptions& device_options,
+ int* flags) {
static_assert(DeviceOptions::kMask == 0xf,
"needs update for new device options");
@@ -580,7 +598,7 @@
GetDriverExecutor()->Schedule([context, ptx_contents, module, &ret,
¬ification]() {
ScopedActivateContext activation(context);
- void *ptx_data = const_cast<char *>(ptx_contents);
+ void* ptx_data = const_cast<char*>(ptx_contents);
static const unsigned int kLogBufferBytesLimit = 1024;
unsigned int error_log_buffer_bytes = kLogBufferBytesLimit;
unsigned int info_log_buffer_bytes = kLogBufferBytesLimit;
@@ -593,12 +611,12 @@
CU_JIT_INFO_LOG_BUFFER, CU_JIT_LOG_VERBOSE};
// Note that the driver API wants the contents of this values to be stored
// in an array of void*s, so we coerce them accordingly.
- void *option_values[] = {
- absl::bit_cast<void *>(uintptr_t(error_log_buffer_bytes)),
- absl::bit_cast<void *>(error_log_buffer.data()),
- absl::bit_cast<void *>(uintptr_t(info_log_buffer_bytes)),
- absl::bit_cast<void *>(info_log_buffer.data()),
- absl::bit_cast<void *>(uintptr_t(log_verbose))};
+ void* option_values[] = {
+ absl::bit_cast<void*>(uintptr_t(error_log_buffer_bytes)),
+ absl::bit_cast<void*>(error_log_buffer.data()),
+ absl::bit_cast<void*>(uintptr_t(info_log_buffer_bytes)),
+ absl::bit_cast<void*>(info_log_buffer.data()),
+ absl::bit_cast<void*>(uintptr_t(log_verbose))};
CHECK(TF_ARRAYSIZE(options) == TF_ARRAYSIZE(option_values));
CUresult res;
@@ -622,8 +640,8 @@
LOG(ERROR) << "failed to load PTX text as a module: " << ToString(res);
// As a precaution for null termination of the API-provided value, ensure
// that at least the last byte is null.
- error_log_buffer[error_log_buffer_bytes ?
- error_log_buffer_bytes - 1 : 0] = '\0';
+ error_log_buffer[error_log_buffer_bytes ? error_log_buffer_bytes - 1
+ : 0] = '\0';
LOG(ERROR) << "error log buffer (" << error_log_buffer_bytes
<< " bytes): " << error_log_buffer.data();
ret = false;
@@ -828,7 +846,7 @@
<< " bytes) from device: " << ToString(res);
return nullptr;
}
- void *ptr = reinterpret_cast<void *>(result);
+ void* ptr = reinterpret_cast<void*>(result);
VLOG(2) << "allocated " << ptr << " for context " << context->context()
<< " of " << bytes << " bytes";
return ptr;
@@ -860,7 +878,7 @@
<< " bytes unified memory; result: " << ToString(res);
return nullptr;
}
- void *ptr = reinterpret_cast<void *>(result);
+ void* ptr = reinterpret_cast<void*>(result);
VLOG(2) << "allocated " << ptr << " for context " << context->context()
<< " of " << bytes << " bytes in unified memory";
return ptr;
@@ -882,7 +900,7 @@
/* static */ void* GpuDriver::HostAllocate(GpuContext* context, uint64 bytes) {
ScopedActivateContext activation(context);
- void *host_mem = nullptr;
+ void* host_mem = nullptr;
// "Portable" memory is visible to all CUDA contexts. Safe for our use model.
CUresult res = tensorflow::wrap::cuMemHostAlloc(&host_mem, bytes,
CU_MEMHOSTALLOC_PORTABLE);
@@ -1074,13 +1092,19 @@
CUdeviceptr gpu_src,
uint64 size) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(gpu_src))
+ << "Source pointer is not actually on GPU: " << gpu_src;
+ CHECK(PointerIsValid(host_dst))
+ << "Destination pointer is not actually on CPU: " << host_dst;
+ }
CUresult res = tensorflow::wrap::cuMemcpyDtoH(host_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(
port::Printf("failed to synchronous memcpy from device to host: %s; "
"host dst: %p; GPU src: %p; size: %llu=0x%llx",
ToString(res).c_str(), host_dst,
- absl::bit_cast<void *>(gpu_src), size, size));
+ absl::bit_cast<void*>(gpu_src), size, size));
}
VLOG(2) << "successfully sync memcpy'd d2h of " << size << " bytes to "
<< host_dst;
@@ -1092,12 +1116,18 @@
const void* host_src,
uint64 size) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(host_src))
+ << "Source pointer is not actually on CPU: " << host_src;
+ CHECK(PointerIsValid(gpu_dst))
+ << "Destination pointer is not actually on GPU: " << gpu_dst;
+ }
CUresult res = tensorflow::wrap::cuMemcpyHtoD(gpu_dst, host_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
"failed to synchronous memcpy from host to device: %s; GPU dst: %p;"
" host src: %p; size: %llu=0x%llx",
- ToString(res).c_str(), absl::bit_cast<void *>(gpu_dst), host_src, size,
+ ToString(res).c_str(), absl::bit_cast<void*>(gpu_dst), host_src, size,
size));
}
VLOG(2) << "successfully enqueued sync memcpy h2d of " << size << " bytes";
@@ -1109,13 +1139,19 @@
CUdeviceptr gpu_src,
uint64 size) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(gpu_src))
+ << "Source pointer is not actually on GPU: " << gpu_src;
+ CHECK(PointerIsValid(gpu_dst))
+ << "Destination pointer is not actually on GPU: " << gpu_dst;
+ }
CUresult res = tensorflow::wrap::cuMemcpyDtoD(gpu_dst, gpu_src, size);
if (res != CUDA_SUCCESS) {
return port::InternalError(port::Printf(
"failed to synchronous memcpy from host to device: %s; GPU dst: %p; "
"GPU src: %p; size: %llu=0x%llx",
- ToString(res).c_str(), absl::bit_cast<void *>(gpu_dst),
- absl::bit_cast<void *>(gpu_src), size, size));
+ ToString(res).c_str(), absl::bit_cast<void*>(gpu_dst),
+ absl::bit_cast<void*>(gpu_src), size, size));
}
VLOG(2) << "successfully sync memcpy'd d2d of " << size << " bytes";
return port::Status::OK();
@@ -1127,18 +1163,24 @@
uint64 size,
CUstream stream) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(gpu_src))
+ << "Source pointer is not actually on GPU: " << gpu_src;
+ CHECK(PointerIsValid(host_dst))
+ << "Destination pointer is not actually on CPU: " << host_dst;
+ }
CUresult res =
tensorflow::wrap::cuMemcpyDtoHAsync(host_dst, gpu_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
"failed to enqueue async memcpy from device to host: %s; host dst: %p; "
"GPU src: %p; size: %llu=0x%llx",
- ToString(res).c_str(), host_dst, absl::bit_cast<void *>(gpu_src), size,
+ ToString(res).c_str(), host_dst, absl::bit_cast<void*>(gpu_src), size,
size);
return false;
}
VLOG(2) << "successfully enqueued async memcpy d2h of " << size
- << " bytes from " << absl::bit_cast<void *>(gpu_src) << " to "
+ << " bytes from " << absl::bit_cast<void*>(gpu_src) << " to "
<< host_dst << " on stream " << stream;
return true;
}
@@ -1149,13 +1191,19 @@
uint64 size,
CUstream stream) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(host_src))
+ << "Source pointer is not actually on CPU: " << host_src;
+ CHECK(PointerIsValid(gpu_dst))
+ << "Destination pointer is not actually on GPU: " << gpu_dst;
+ }
CUresult res =
tensorflow::wrap::cuMemcpyHtoDAsync(gpu_dst, host_src, size, stream);
if (res != CUDA_SUCCESS) {
LOG(ERROR) << port::Printf(
"failed to enqueue async memcpy from host to device: %s; GPU dst: %p; "
"host src: %p; size: %llu=0x%llx",
- ToString(res).c_str(), absl::bit_cast<void *>(gpu_dst), host_src, size,
+ ToString(res).c_str(), absl::bit_cast<void*>(gpu_dst), host_src, size,
size);
return false;
}
@@ -1170,6 +1218,12 @@
uint64 size,
CUstream stream) {
ScopedActivateContext activation(context);
+ if (size > 0) {
+ CHECK(PointerIsValid(gpu_src))
+ << "Source pointer is not actually on GPU: " << gpu_src;
+ CHECK(PointerIsValid(gpu_dst))
+ << "Destination pointer is not actually on GPU: " << gpu_dst;
+ }
CUresult result =
tensorflow::wrap::cuMemcpyDtoDAsync(gpu_dst, gpu_src, size, stream);
if (result != CUDA_SUCCESS) {
@@ -1178,10 +1232,10 @@
"; GPU dst: %p on %s %s"
"; GPU src: %p on %s %s"
"; can access? %s; size: %llu=0x%llx",
- ToString(result).c_str(), absl::bit_cast<void *>(gpu_dst),
+ ToString(result).c_str(), absl::bit_cast<void*>(gpu_dst),
CUDAPointerToMemorySpaceString(gpu_dst).c_str(),
CUDAPointerToDeviceString(gpu_dst).c_str(),
- absl::bit_cast<void *>(gpu_src),
+ absl::bit_cast<void*>(gpu_src),
CUDAPointerToMemorySpaceString(gpu_src).c_str(),
CUDAPointerToDeviceString(gpu_src).c_str(),
CUDAPointersToCanAccessString(gpu_src, gpu_dst).c_str(), size, size);
@@ -1289,13 +1343,13 @@
return port::Status(
port::error::NOT_FOUND,
port::Printf("not a device pointer %p; %s",
- reinterpret_cast<void *>(dptr), ToString(result).c_str()));
+ reinterpret_cast<void*>(dptr), ToString(result).c_str()));
}
return port::Status(
port::error::INTERNAL,
port::Printf("failed to get pointer into for device pointer %p; %s",
- reinterpret_cast<void *>(dptr), ToString(result).c_str()));
+ reinterpret_cast<void*>(dptr), ToString(result).c_str()));
}
/* static */ port::StatusOr<CUdevice> GpuDriver::GetPointerDevice(
diff --git a/tensorflow/stream_executor/cuda/cudart_stub.cc b/tensorflow/stream_executor/cuda/cudart_stub.cc
new file mode 100644
index 0000000..c5fc43d
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cudart_stub.cc
@@ -0,0 +1,121 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file wraps cuda runtime calls with dso loader so that we don't need to
+// have explicit linking to libcuda.
+
+#include "cuda/include/cuda_runtime_api.h"
+#include "tensorflow/stream_executor/lib/env.h"
+#include "tensorflow/stream_executor/platform/dso_loader.h"
+
+namespace {
+void *GetDsoHandle() {
+ static auto handle = [] {
+ void *result = nullptr;
+ using DsoLoader = stream_executor::internal::DsoLoader;
+ DsoLoader::GetLibcudartDsoHandle(&result).IgnoreError();
+ return result;
+ }();
+ return handle;
+}
+
+template <typename T>
+T LoadSymbol(const char *symbol_name) {
+ void *symbol = nullptr;
+ auto env = stream_executor::port::Env::Default();
+ env->GetSymbolFromLibrary(GetDsoHandle(), symbol_name, &symbol).IgnoreError();
+ return reinterpret_cast<T>(symbol);
+}
+cudaError_t GetSymbolNotFoundError() {
+ return cudaErrorSharedObjectSymbolNotFound;
+}
+const char *GetSymbolNotFoundStrError() {
+ return "cudaErrorSharedObjectSymbolNotFound";
+}
+} // namespace
+
+// Code below is auto-generated.
+extern "C" {
+cudaError_t CUDART_CB cudaFree(void *devPtr) {
+ using FuncPtr = cudaError_t (*)(void *devPtr);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaFree");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(devPtr);
+}
+
+cudaError_t CUDART_CB cudaGetDevice(int *device) {
+ using FuncPtr = cudaError_t (*)(int *device);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDevice");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(device);
+}
+
+cudaError_t CUDART_CB cudaGetDeviceProperties(cudaDeviceProp *prop,
+ int device) {
+ using FuncPtr = cudaError_t (*)(cudaDeviceProp * prop, int device);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceProperties");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(prop, device);
+}
+
+const char *CUDART_CB cudaGetErrorString(cudaError_t error) {
+ using FuncPtr = const char *(*)(cudaError_t error);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetErrorString");
+ if (!func_ptr) return GetSymbolNotFoundStrError();
+ return func_ptr(error);
+}
+
+cudaError_t CUDART_CB cudaSetDevice(int device) {
+ using FuncPtr = cudaError_t (*)(int device);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaSetDevice");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(device);
+}
+
+cudaError_t CUDART_CB cudaStreamAddCallback(cudaStream_t stream,
+ cudaStreamCallback_t callback,
+ void *userData,
+ unsigned int flags) {
+ using FuncPtr =
+ cudaError_t (*)(cudaStream_t stream, cudaStreamCallback_t callback,
+ void *userData, unsigned int flags);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaStreamAddCallback");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(stream, callback, userData, flags);
+}
+
+cudaError_t CUDART_CB cudaGetDeviceCount(int *count) {
+ using FuncPtr = cudaError_t (*)(int *count);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetDeviceCount");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(count);
+}
+
+cudaError_t CUDART_CB cudaPointerGetAttributes(
+ struct cudaPointerAttributes *attributes, const void *ptr) {
+ using FuncPtr = cudaError_t (*)(struct cudaPointerAttributes * attributes,
+ const void *ptr);
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaPointerGetAttributes");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr(attributes, ptr);
+}
+
+cudaError_t CUDART_CB cudaGetLastError() {
+ using FuncPtr = cudaError_t (*)();
+ static auto func_ptr = LoadSymbol<FuncPtr>("cudaGetLastError");
+ if (!func_ptr) return GetSymbolNotFoundError();
+ return func_ptr();
+}
+} // extern "C"
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 1d9a2be..24c2948 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1648,21 +1648,9 @@
return false;
}
- // Applies local response normalization to the values from
- // input_data and writes the result to output_data. See comments on
- // NormalizeDescriptor for a description of local response
- // normalization.
- virtual bool DoNormalize(Stream* stream,
- const dnn::NormalizeDescriptor& normalize_descriptor,
- const DeviceMemory<float>& input_data,
- DeviceMemory<float>* output_data) = 0;
-
// Applies local response normalization to the values from input_data and
// writes the result to output_data.
//
- // Similar to DoNormalize, but normalizes across feature maps and allows for
- // specifying the dimensions of the tensor.
- //
// See comments on NormalizeDescriptor for a description of local response
// normalization.
virtual bool DoNormalizeWithDimensions(
diff --git a/tensorflow/stream_executor/gpu/BUILD b/tensorflow/stream_executor/gpu/BUILD
index 78fe478..e681238 100644
--- a/tensorflow/stream_executor/gpu/BUILD
+++ b/tensorflow/stream_executor/gpu/BUILD
@@ -25,7 +25,7 @@
cc_library(
name = "gpu_activation_header",
- hdrs = if_gpu_is_configured(["gpu_activation.h"]),
+ hdrs = ["gpu_activation.h"],
deps = ["//tensorflow/stream_executor/platform"],
)
@@ -33,13 +33,13 @@
name = "gpu_activation",
srcs = if_gpu_is_configured(["gpu_activation.cc"]),
hdrs = if_gpu_is_configured(["gpu_activation.h"]),
- deps = [
+ deps = if_gpu_is_configured([
":gpu_activation_header",
":gpu_driver_header",
"//tensorflow/stream_executor",
"//tensorflow/stream_executor:stream_executor_internal",
"//tensorflow/stream_executor/platform",
- ],
+ ]),
)
cc_library(
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.cc b/tensorflow/stream_executor/platform/default/dso_loader.cc
index 668eeee..8592455 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.cc
+++ b/tensorflow/stream_executor/platform/default/dso_loader.cc
@@ -117,6 +117,13 @@
#endif
}
+/* static */ port::Status DsoLoader::GetLibcudartDsoHandle(void** dso_handle) {
+ return GetDsoHandle(FindDsoPath(port::Env::Default()->FormatLibraryFileName(
+ "cudart", GetCudaVersion()),
+ GetCudaLibraryDirPath()),
+ dso_handle);
+}
+
static mutex& GetRpathMutex() {
static mutex* mu = new mutex;
return *mu;
@@ -282,6 +289,12 @@
return result;
}
+/* static */ port::StatusOr<void*> CachedDsoLoader::GetLibcudartDsoHandle() {
+ static port::StatusOr<void*> result =
+ FetchHandleResult(DsoLoader::GetLibcudartDsoHandle);
+ return result;
+}
+
/* static */ port::StatusOr<void*> CachedDsoLoader::FetchHandleResult(
std::function<port::Status(void**)> load_dso) {
void* handle;
diff --git a/tensorflow/stream_executor/platform/default/dso_loader.h b/tensorflow/stream_executor/platform/default/dso_loader.h
index 806f65b..92c0db7 100644
--- a/tensorflow/stream_executor/platform/default/dso_loader.h
+++ b/tensorflow/stream_executor/platform/default/dso_loader.h
@@ -46,6 +46,7 @@
static port::Status GetCurandDsoHandle(void** dso_handle);
static port::Status GetLibcudaDsoHandle(void** dso_handle);
static port::Status GetLibcuptiDsoHandle(void** dso_handle);
+ static port::Status GetLibcudartDsoHandle(void** dso_handle);
// Registers a new binary-relative path to use as a dlopen search path.
static void RegisterRpath(absl::string_view path);
@@ -101,6 +102,7 @@
static port::StatusOr<void*> GetCurandDsoHandle();
static port::StatusOr<void*> GetLibcudaDsoHandle();
static port::StatusOr<void*> GetLibcuptiDsoHandle();
+ static port::StatusOr<void*> GetLibcudartDsoHandle();
private:
// Fetches a DSO handle via "load_dso" and returns the StatusOr form of the
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index e7485ca..2577d38 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -281,6 +281,12 @@
}
}
+port::Status Stream::RefreshStatus() {
+ port::Status status = parent_->GetStatus(this);
+ CheckStatus(status);
+ return status;
+}
+
Stream &Stream::Init() {
VLOG_CALL();
@@ -431,172 +437,6 @@
return *this;
}
-Stream &Stream::ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<int8> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<int8> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
- PARAM(conv_input_scale), PARAM(filter_descriptor),
- PARAM(filter_data), PARAM(convolution_descriptor),
- PARAM(side_input_data), PARAM(side_input_scale),
- PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
- PARAM(output_descriptor), PARAM(output));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(dnn->DoFusedConvolve(
- this, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor,
- side_input_data, side_input_scale, bias_descriptor, biases,
- activation_mode, output_descriptor, output, scratch_allocator,
- dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
-Stream &Stream::ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<Eigen::half> &biases,
- dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
- PARAM(conv_input_scale), PARAM(filter_descriptor),
- PARAM(filter_data), PARAM(convolution_descriptor),
- PARAM(side_input_data), PARAM(side_input_scale),
- PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
- PARAM(output_descriptor), PARAM(output));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(dnn->DoFusedConvolve(
- this, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor,
- side_input_data, side_input_scale, bias_descriptor, biases,
- activation_mode, output_descriptor, output, scratch_allocator,
- dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
-Stream &Stream::ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<float> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<float> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
- PARAM(conv_input_scale), PARAM(filter_descriptor),
- PARAM(filter_data), PARAM(convolution_descriptor),
- PARAM(side_input_data), PARAM(side_input_scale),
- PARAM(bias_descriptor), PARAM(biases), PARAM(activation_mode),
- PARAM(output_descriptor), PARAM(output));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(dnn->DoFusedConvolve(
- this, conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor,
- side_input_data, side_input_scale, bias_descriptor, biases,
- activation_mode, output_descriptor, output, scratch_allocator,
- dnn::AlgorithmConfig(), /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
-Stream &Stream::ThenConvolveWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<Eigen::half> &input_data,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
- PARAM(filter_descriptor), PARAM(filter_data),
- PARAM(convolution_descriptor), PARAM(output_descriptor),
- PARAM(output));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
- filter_descriptor, filter_data, output_descriptor, *output,
- convolution_descriptor, dnn::AlgorithmConfig(), scratch_allocator,
- &algorithm_desc, &scratch_memory));
- CheckError(dnn->DoConvolve(
- this, input_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output, algorithm_desc,
- &scratch_memory, nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
-Stream &Stream::ThenConvolveWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
- PARAM(filter_descriptor), PARAM(filter_data),
- PARAM(convolution_descriptor), PARAM(output_descriptor),
- PARAM(output));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::FORWARD, this, input_descriptor, input_data,
- filter_descriptor, filter_data, output_descriptor, *output,
- convolution_descriptor, dnn::AlgorithmConfig(), scratch_allocator,
- &algorithm_desc, &scratch_memory));
- CheckError(dnn->DoConvolve(
- this, input_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output, algorithm_desc,
- &scratch_memory, nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
Stream &Stream::ThenFusedConvolveWithAlgorithm(
const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<double> &conv_input_data, double conv_input_scale,
@@ -876,24 +716,6 @@
return *this;
}
-Stream &Stream::ThenFusedConvolve(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<int8> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<int8> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output) {
- return ThenFusedConvolveWithScratch(
- conv_input_descriptor, conv_input_data, conv_input_scale,
- filter_descriptor, filter_data, convolution_descriptor, side_input_data,
- side_input_scale, bias_descriptor, biases, activation_mode,
- output_descriptor, output,
- /*scratch_allocator=*/nullptr);
-}
-
Stream &Stream::ThenConvolve(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
@@ -902,10 +724,11 @@
const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> *output) {
- return ThenConvolveWithScratch(input_descriptor, input_data,
- filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor,
- output, /*scratch_allocator=*/nullptr);
+ return ThenConvolveWithAlgorithm(
+ input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output,
+ /*scratch_allocator=*/nullptr, dnn::AlgorithmConfig(),
+ /*output_profile_result=*/nullptr);
}
Stream &Stream::ThenConvolveQuantized(
@@ -995,42 +818,6 @@
return *this;
}
-Stream &Stream::ThenConvolveBackwardDataWithScratch(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<float> *backward_input_data,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
- PARAM(output_descriptor), PARAM(backward_output_data),
- PARAM(convolution_descriptor), PARAM(input_descriptor),
- PARAM(backward_input_data));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
- *backward_input_data, filter_descriptor, filter_data,
- output_descriptor, backward_output_data, convolution_descriptor,
- dnn::AlgorithmConfig(), scratch_allocator, &algorithm_desc,
- &scratch_memory));
- CheckError(dnn->DoConvolveBackwardData(
- this, filter_descriptor, filter_data, output_descriptor,
- backward_output_data, convolution_descriptor, input_descriptor,
- backward_input_data, algorithm_desc, &scratch_memory,
- /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
Stream &Stream::ThenConvolveBackwardDataWithAlgorithm(
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<double> &filter_data,
@@ -1166,92 +953,6 @@
return *this;
}
-Stream &Stream::ThenConvolveBackwardDataWithScratch(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<Eigen::half> *backward_input_data,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
- PARAM(output_descriptor), PARAM(backward_output_data),
- PARAM(convolution_descriptor), PARAM(input_descriptor),
- PARAM(backward_input_data));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::BACKWARD_DATA, this, input_descriptor,
- *backward_input_data, filter_descriptor, filter_data,
- output_descriptor, backward_output_data, convolution_descriptor,
- dnn::AlgorithmConfig(), scratch_allocator, &algorithm_desc,
- &scratch_memory));
- CheckError(dnn->DoConvolveBackwardData(
- this, filter_descriptor, filter_data, output_descriptor,
- backward_output_data, convolution_descriptor, input_descriptor,
- backward_input_data, algorithm_desc, &scratch_memory,
- /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
-Stream &Stream::ThenConvolveBackwardData(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<float> *backward_input_data) {
- return ThenConvolveBackwardDataWithScratch(
- filter_descriptor, filter_data, output_descriptor, backward_output_data,
- convolution_descriptor, input_descriptor, backward_input_data,
- /*scratch_allocator=*/nullptr);
-}
-
-Stream &Stream::ThenConvolveBackwardFilterWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<float> *backward_filter_data,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
- PARAM(output_descriptor), PARAM(backward_output_data),
- PARAM(convolution_descriptor), PARAM(filter_descriptor),
- PARAM(backward_filter_data));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
- input_data, filter_descriptor, *backward_filter_data,
- output_descriptor, backward_output_data, convolution_descriptor,
- dnn::AlgorithmConfig(), scratch_allocator, &algorithm_desc,
- &scratch_memory));
- CheckError(dnn->DoConvolveBackwardFilter(
- this, input_descriptor, input_data, output_descriptor,
- backward_output_data, convolution_descriptor, filter_descriptor,
- backward_filter_data, algorithm_desc, &scratch_memory,
- /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<double> &input_data,
@@ -1342,42 +1043,6 @@
return *this;
}
-Stream &Stream::ThenConvolveBackwardFilterWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<Eigen::half> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<Eigen::half> *backward_filter_data,
- ScratchAllocator *scratch_allocator) {
- VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
- PARAM(output_descriptor), PARAM(backward_output_data),
- PARAM(convolution_descriptor), PARAM(filter_descriptor),
- PARAM(backward_filter_data));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- DeviceMemory<uint8> scratch_memory;
- dnn::AlgorithmDesc algorithm_desc;
- CheckStatus(dnn->PrepareForConvolution(
- dnn::ConvolutionKind::BACKWARD_FILTER, this, input_descriptor,
- input_data, filter_descriptor, *backward_filter_data,
- output_descriptor, backward_output_data, convolution_descriptor,
- dnn::AlgorithmConfig(), scratch_allocator, &algorithm_desc,
- &scratch_memory));
- CheckError(dnn->DoConvolveBackwardFilter(
- this, input_descriptor, input_data, output_descriptor,
- backward_output_data, convolution_descriptor, filter_descriptor,
- backward_filter_data, algorithm_desc, &scratch_memory,
- /*output_profile_result=*/nullptr));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
Stream &Stream::ThenConvolveBackwardFilterWithAlgorithm(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<Eigen::half> &input_data,
@@ -1423,20 +1088,6 @@
return *this;
}
-Stream &Stream::ThenConvolveBackwardFilter(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<float> *backward_filter_data) {
- return ThenConvolveBackwardFilterWithScratch(
- input_descriptor, input_data, output_descriptor, backward_output_data,
- convolution_descriptor, filter_descriptor, backward_filter_data,
- /*scratch_allocator=*/nullptr);
-}
-
template <typename T>
Stream &Stream::ThenConvolveBackwardBiasImpl(
const dnn::BatchDescriptor &input_descriptor,
@@ -1742,22 +1393,6 @@
return *this;
}
-Stream &Stream::ThenNormalize(
- const dnn::NormalizeDescriptor &normalize_descriptor,
- const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
- VLOG_CALL(PARAM(normalize_descriptor), PARAM(input_data), PARAM(output_data));
-
- if (ok()) {
- if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(dnn->DoNormalize(this, normalize_descriptor, input_data,
- output_data));
- } else {
- SetErrorAndLogNoDnnSupport();
- }
- }
- return *this;
-}
-
Stream &Stream::ThenNormalizeWithDimensions(
const dnn::NormalizeDescriptor &normalize_descriptor,
const dnn::BatchDescriptor &dimensions,
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index f698d50..3e67d55 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -109,6 +109,17 @@
// stream.
bool ok() const { return !InErrorState(); }
+ // Retrieves execution status back into the stream from the underlying
+ // implementation without blocking the stream.
+ //
+ // Normally, Stream::BlockHostUntilDone is used to get execution status.
+ // However, some devices use out-of-band mechnanisms to ensure their streams
+ // have finished on-device work, without needing to block the streams. (These
+ // devices should also override AllowsSyncOnCompletion to return false.) For
+ // these devices, this method can be used after work is finished to retrieve
+ // execution status.
+ port::Status RefreshStatus() LOCKS_EXCLUDED(mu_);
+
// Initialize the stream. This must be performed before entraining any other
// operations.
Stream &Init() LOCKS_EXCLUDED(mu_);
@@ -262,19 +273,6 @@
DeviceMemory<float> *scale_backprop,
DeviceMemory<float> *offset_backprop);
- // TODO(leary) add double-precision version of this interface.
- Stream &ThenFusedConvolve(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<int8> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<int8> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<int8> *output);
-
Stream &ThenConvolve(const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::FilterDescriptor &filter_descriptor,
@@ -303,61 +301,6 @@
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> *output_data);
- Stream &ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<int8> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<int8> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<int8> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor, DeviceMemory<int8> *output,
- ScratchAllocator *scratch_allocator);
-
- Stream &ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<Eigen::half> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<Eigen::half> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<Eigen::half> &biases,
- dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);
-
- Stream &ThenFusedConvolveWithScratch(
- const dnn::BatchDescriptor &conv_input_descriptor,
- const DeviceMemory<float> &conv_input_data, float conv_input_scale,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const DeviceMemory<float> &side_input_data, float side_input_scale,
- const dnn::BatchDescriptor &bias_descriptor,
- const DeviceMemory<float> &biases, dnn::ActivationMode activation_mode,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
-
- Stream &ThenConvolveWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<Eigen::half> &input_data,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> *output, ScratchAllocator *scratch_allocator);
-
- Stream &ThenConvolveWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
-
Stream &ThenConvolveWithAlgorithm(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<double> &input_data,
@@ -458,35 +401,6 @@
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> *output);
- Stream &ThenConvolveBackwardData(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<float> *backward_input_data);
-
- Stream &ThenConvolveBackwardDataWithScratch(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<float> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<float> *backward_input_data,
- ScratchAllocator *scratch_allocator);
-
- Stream &ThenConvolveBackwardDataWithScratch(
- const dnn::FilterDescriptor &filter_descriptor,
- const DeviceMemory<Eigen::half> &filter_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<Eigen::half> *backward_input_data,
- ScratchAllocator *scratch_allocator);
-
Stream &ThenConvolveBackwardDataWithAlgorithm(
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<double> &filter_data,
@@ -523,35 +437,6 @@
const dnn::AlgorithmConfig &algorithm_config,
dnn::ProfileResult *output_profile_result);
- Stream &ThenConvolveBackwardFilter(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<float> *backward_filter_data);
-
- Stream &ThenConvolveBackwardFilterWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<float> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<float> *backward_filter_data,
- ScratchAllocator *scratch_allocator);
-
- Stream &ThenConvolveBackwardFilterWithScratch(
- const dnn::BatchDescriptor &input_descriptor,
- const DeviceMemory<Eigen::half> &input_data,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<Eigen::half> backward_output_data,
- const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<Eigen::half> *backward_filter_data,
- ScratchAllocator *scratch_allocator);
-
Stream &ThenConvolveBackwardFilterWithAlgorithm(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<double> &input_data,
@@ -684,12 +569,6 @@
DeviceMemory<Eigen::half> *output_diff_data,
ScratchAllocator *workspace_allocator = nullptr);
- Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
- const DeviceMemory<float> &input_data,
- DeviceMemory<float> *output_data);
-
- // Similar to ThenNormalize, but normalizes across feature maps and allows for
- // specifying the dimensions of the tensor.
Stream &ThenNormalizeWithDimensions(
const dnn::NormalizeDescriptor &normalize_descriptor,
const dnn::BatchDescriptor &dimensions,
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 6138554..e234e5d 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -253,6 +253,10 @@
virtual bool StartTimer(Stream *stream, Timer *timer) = 0;
virtual bool StopTimer(Stream *stream, Timer *timer) = 0;
virtual port::Status BlockHostUntilDone(Stream *stream) = 0;
+ virtual port::Status GetStatus(Stream *stream) {
+ return port::Status(port::error::UNIMPLEMENTED,
+ "GetStatus is not supported on this executor.");
+ }
virtual int PlatformDeviceCount() = 0;
virtual port::Status EnablePeerAccessTo(StreamExecutorInterface *other) = 0;
virtual bool CanEnablePeerAccessTo(StreamExecutorInterface *other) = 0;
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 6f0ba51..c680a02 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -492,6 +492,10 @@
return result;
}
+port::Status StreamExecutor::GetStatus(Stream *stream) {
+ return implementation_->GetStatus(stream);
+}
+
void *StreamExecutor::Allocate(uint64 size) {
if (memory_limit_bytes_ > 0 &&
mem_alloc_bytes_ + size > memory_limit_bytes_) {
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index ad2bc3c..508273e 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -524,6 +524,9 @@
// operations enqueued on the stream before this program point.
port::Status BlockHostUntilDone(Stream *stream);
+ // Without blocking the device, retrieve the current stream status.
+ port::Status GetStatus(Stream *stream);
+
// Synchronously allocates size bytes on the underlying platform and returns
// an opaque void* representing that allocation. In the case of failure,
// nullptr is returned.
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 874d9e8..6c8b445 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -97,6 +97,11 @@
for p in core_proto_sources_relative
])
+# Wrapper for portable protos which currently just creates an empty rule.
+def tf_portable_proto_library(name, proto_deps, **kwargs):
+ _ignore = [kwargs]
+ native.cc_library(name = name, deps = proto_deps)
+
# Sanitize a dependency so that it works correctly from code that includes
# TensorFlow as a submodule.
def clean_dep(dep):
@@ -146,6 +151,12 @@
"//conditions:default": [],
})
+def if_emscripten(a):
+ return select({
+ clean_dep("//tensorflow:emscripten"): a,
+ "//conditions:default": [],
+ })
+
def if_ios(a):
return select({
clean_dep("//tensorflow:ios"): a,
@@ -306,9 +317,19 @@
# LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt)
+def tf_opts_nortti_if_emscripten():
+ return if_emscripten([
+ "-fno-rtti",
+ "-DGOOGLE_PROTOBUF_NO_RTTI",
+ "-DGOOGLE_PROTOBUF_NO_STATIC_INITIALIZER",
+ ])
+
def tf_features_nomodules_if_android():
return if_android(["-use_header_modules"])
+def tf_features_nomodules_if_emscripten():
+ return if_emscripten(["-use_header_modules"])
+
# Given a list of "op_lib_names" (a list of files in the ops directory
# without their .cc extensions), generate a library for that file.
def tf_gen_op_libs(op_lib_names, deps = None, is_external = True):
@@ -1132,7 +1153,7 @@
kwargs["features"] = kwargs.get("features", []) + ["-use_header_modules"]
native.cc_library(
deps = deps + if_cuda_is_configured_compat(cuda_deps + [
- clean_dep("//tensorflow/core:cuda"),
+ clean_dep("//tensorflow/stream_executor/cuda:cudart_stub"),
"@local_config_cuda//cuda:cuda_headers",
]) + if_rocm_is_configured(cuda_deps + [
# rocm_header placeholder
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt
index f79029d..cc2d5c8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-aggregation-method.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.AggregationMethod"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.gradients_impl.AggregationMethod\'>"
+ is_instance: "<class \'tensorflow.python.ops.gradients_util.AggregationMethod\'>"
is_instance: "<type \'object\'>"
member {
name: "ADD_N"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
index 078f102..2e8ece1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.-experimental.pbtxt
@@ -32,6 +32,12 @@
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "collective_nccl"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
reserved_range {
start: 2
end: 3
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
index d2ee0c4..9c7de2c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-config-proto.pbtxt
@@ -155,6 +155,12 @@
label: LABEL_OPTIONAL
type: TYPE_BOOL
}
+ field {
+ name: "collective_nccl"
+ number: 7
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
reserved_range {
start: 2
end: 3
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt
index a2cc074..6c528dd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-g-p-u-options.pbtxt
@@ -84,6 +84,18 @@
label: LABEL_OPTIONAL
type: TYPE_STRING
}
+ field {
+ name: "timestamped_allocator"
+ number: 5
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
+ field {
+ name: "pending_cap"
+ number: 6
+ label: LABEL_OPTIONAL
+ type: TYPE_INT32
+ }
nested_type {
name: "VirtualDevices"
field {
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.audio.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.audio.pbtxt
index ce29615..6c57240 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.audio.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.audio.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.audio"
tf_module {
member_method {
+ name: "decode_wav"
+ argspec: "args=[\'contents\', \'desired_channels\', \'desired_samples\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
+ }
+ member_method {
name: "encode_wav"
argspec: "args=[\'audio\', \'sample_rate\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.experimental.-module.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.experimental.-module.pbtxt
index c364b02..3c5add1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.experimental.-module.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.experimental.-module.pbtxt
@@ -13,18 +13,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "owned_submodules"
- mtype: "<type \'property\'>"
- }
- member {
- name: "owned_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "owned_variables"
- mtype: "<type \'property\'>"
- }
- member {
name: "submodules"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
index 283cc6a..bb44ba0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-model.pbtxt
@@ -135,7 +135,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
index 95e405a..44fc15e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.-sequential.pbtxt
@@ -140,7 +140,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt
new file mode 100644
index 0000000..58bede5
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-cosine-decay-restarts.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.experimental.CosineDecayRestarts"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.CosineDecayRestarts\'>"
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'initial_learning_rate\', \'first_decay_steps\', \'t_mul\', \'m_mul\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'2.0\', \'1.0\', \'0.0\', \'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt
new file mode 100644
index 0000000..f083120
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-linear-cosine-decay.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.experimental.LinearCosineDecay"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LinearCosineDecay\'>"
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'0.0\', \'0.001\', \'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt
new file mode 100644
index 0000000..8ea3c6b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-noisy-linear-cosine-decay.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.experimental.NoisyLinearCosineDecay"
+tf_class {
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.NoisyLinearCosineDecay\'>"
+ is_instance: "<class \'tensorflow.python.keras.optimizer_v2.learning_rate_schedule.LearningRateSchedule\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'initial_learning_rate\', \'decay_steps\', \'initial_variance\', \'variance_decay\', \'num_periods\', \'alpha\', \'beta\', \'name\'], varargs=None, keywords=None, defaults=[\'1.0\', \'0.55\', \'0.5\', \'0.0\', \'0.001\', \'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
index 1bfd51c..4f7ace4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
index 24684b9..721c188 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.experimental.pbtxt
@@ -5,6 +5,18 @@
mtype: "<type \'type\'>"
}
member {
+ name: "CosineDecayRestarts"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "LinearCosineDecay"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "NoisyLinearCosineDecay"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "PeepholeLSTMCell"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
index 8a0b8eb..eab888c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activation.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
index abb3c23..96c7acc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
index b27db4e..9e8aae1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-add.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 50998ac..01fc730 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index be17aea..8b6a151 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 7f21b44..3c78457 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index 2ac86f1..e6e96a0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
index f6b1dd2..ec2d5b1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-average.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index 3da1f43..afff790 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index a7be5ac..d7ab835 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index c5c29be..6654f86 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
index 3af3c2a..a328d9f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
index 880d18e..94f3a46 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -115,7 +115,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
index 1eb0cf1..e0eae17 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-concatenate.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index d9394e6..ec8a44c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -196,7 +196,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
index a0f6dc8..350d49a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 037b92f..9b48eb6 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
index 6a0d027..1708d6a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index 66b5bd7..5018492 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
index e73133f..fd24af3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
index 7af6b2b..fbc7609 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index baff492..671a004 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
index 63d30a6..dd6519c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 7a29cbb..648f480 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
index 87c75c0..87a07ea 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
index f69104d..6f3a153 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
index aa05471..a1c418c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
index d61f1dd..ad98f9c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
index e2d05f8..e35403b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-g-r-u.pbtxt
@@ -116,7 +116,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
index f650f48..90d03ea 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-cu-d-n-n-l-s-t-m.pbtxt
@@ -116,7 +116,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
similarity index 77%
copy from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
copy to tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
index 9de7307..ca6a327 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense-features.pbtxt
@@ -1,8 +1,7 @@
-path: "tensorflow.nn.rnn_cell.ResidualWrapper"
+path: "tensorflow.keras.layers.DenseFeatures"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
+ is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2.DenseFeatures\'>"
+ is_instance: "<class \'tensorflow.python.feature_column.feature_column_v2._BaseFeaturesLayer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
is_instance: "<class \'tensorflow.python.training.checkpointable.base.Checkpointable\'>"
is_instance: "<type \'object\'>"
@@ -19,10 +18,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
name: "inbound_nodes"
mtype: "<type \'property\'>"
}
@@ -71,18 +66,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
name: "trainable_variables"
mtype: "<type \'property\'>"
}
@@ -104,7 +87,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'residual_fn\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'self\', \'feature_columns\', \'trainable\', \'name\'], varargs=None, keywords=kwargs, defaults=[\'True\', \'None\'], "
}
member_method {
name: "add_loss"
@@ -124,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
@@ -136,7 +119,7 @@
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'features\', \'cols_to_output_tensors\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
name: "compute_mask"
@@ -159,10 +142,6 @@
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
name: "get_input_at"
argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
}
@@ -202,8 +181,4 @@
name: "set_weights"
argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
}
- member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
- }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
index 06e8b6b..ef12b2e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dense.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index 9fdf6f6..eacfb37 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
index cbe1020..7928ceb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dot.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
index 0efba09..a7fa545 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
index b34c499..483ba65 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
index 51dd853..4d0e5e1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-embedding.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
index dcd18a9..5947047 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-flatten.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index f029907..b4efdf3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
index 278ae06..db4d981 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -179,7 +179,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index 15cbcfe..1686768 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 865b898..69bca6a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 3e17aca..9a4119d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index b160687..2ca1eb1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index 70e8d51..4331adc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 809dc85..6e91b4a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index 3fbce8c..85887a5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index 70e4103..dd20fd1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 000bf54..3372ae7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index 8ffbf07..0fb1882 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index 3803d2b..5b1c850 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 2866822..49e59e0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index b83ed67..9504f64 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index e689d69..42de6ae 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
index bb6edda..f388b84 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-input-layer.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 5fb3f9d..d2634dd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 8eb6dd9..94ec432 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -179,7 +179,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
index 376bec0..da2373c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-lambda.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
index c5f91a6..2e47132 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-layer.pbtxt
@@ -105,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index bde8887..a74e935 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 16945f2..0f4c071 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index f05741f..5eea071 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
index 7885db4..a16ceef 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-masking.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 9380d26..e61d730 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
index 8eb8218..a21c403 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
index 0c96f86..fb8613a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 0c6b230..a433d49 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index eb7ca52..fa6ad6f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index e724e90..05e2ace 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
index dafbd09..ce62223 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-maximum.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
index 3122fbe..a0ff4f9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-minimum.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
index 0527cda..558cc0d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-multiply.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 814e5a5..5863fbb 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
index aa1731a..4d7413b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-permute.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
index 9d7dd85..67ab60b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -110,7 +110,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
index e9bba29..eb32ba2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
index 3c783eb..81ac253 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
index b8e0882..dd4dc49 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-reshape.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 310f369..c8724f0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index df19d78..8c47395 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index bf90950..c0b6ad4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 5d66bc6..c5566c1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 88e9300..f91aac8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index 9d81c6d..eb2a7b9 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -167,7 +167,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
index 712eb0c..f0411e2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-softmax.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index dfc4ca2..2a2fd2e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 5e4f727..e4d1d43 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index 9d893cb..4e641a8e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index a2ed954..591796e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -114,7 +114,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
index 8a0818e..67555db 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-subtract.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index b5591b4..0ed7da5 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
index 210e4fd..9492b0b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -111,7 +111,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index da2213a..16c31d3 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index e2c303d..cf1a076 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 396e774..5cded98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
index 8b6418d..16f3f06 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-wrapper.pbtxt
@@ -110,7 +110,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index e8fda4c..59997a8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index 50c52d2..9a327c2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 84c6b78..7933868 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
index 476c597..6c8faef 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt
index ad74ab3..cc0fdab 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.pbtxt
@@ -125,6 +125,10 @@
mtype: "<type \'type\'>"
}
member {
+ name: "DenseFeatures"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "DepthwiseConv2D"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
index eb1ab1d..3132e8d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-model.pbtxt
@@ -135,7 +135,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
index c69cf28..b5ef70e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.models.-sequential.pbtxt
@@ -140,7 +140,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-adjoint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-adjoint.pbtxt
new file mode 100644
index 0000000..37344f7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-adjoint.pbtxt
@@ -0,0 +1,150 @@
+path: "tensorflow.linalg.LinearOperatorAdjoint"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_adjoint.LinearOperatorAdjoint\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "operator"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'operator\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "cholesky"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "inverse"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt
index c7a5096..ddef774 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-block-diag.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt
index 3900c75..97a6b1a 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
index 7b87609..e2bfe7e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
index 5bddba8..8885526 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt
index 62ba8bb..2a017fc 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-composition.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt
index 0803fee..31dcf7b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-diag.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
index 6def328..0ad39b4 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt
index dbf1ac8..f66a5a8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-identity.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -61,6 +65,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-inversion.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-inversion.pbtxt
index 6a3fe4d..a7eb144 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-inversion.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-inversion.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt
index 85d902b..c983f8c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-kronecker.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
index 638d82a..813aec2 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "base_operator"
mtype: "<type \'property\'>"
}
@@ -84,6 +88,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
index ab1b04b..0bb7a15 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
index 961969a..7747c98 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -65,6 +69,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt
index e76738a..590782b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator-zeros.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt
index b35cd69..ed6bfdf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.-linear-operator.pbtxt
@@ -3,6 +3,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -59,6 +63,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
index 5e49b75..973850c 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
@@ -5,6 +5,10 @@
mtype: "<type \'type\'>"
}
member {
+ name: "LinearOperatorAdjoint"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearOperatorBlockDiag"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
index 1fe179f..68cb07e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
@@ -4,7 +4,7 @@
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt
index 66e692a..3ac478f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.lite.-op-hint.pbtxt
@@ -15,6 +15,10 @@
mtype: "<type \'str\'>"
}
member {
+ name: "CHILDREN_INPUTS_MAPPINGS"
+ mtype: "<type \'str\'>"
+ }
+ member {
name: "FUNCTION_AGGREGATE_ATTR"
mtype: "<type \'str\'>"
}
@@ -23,6 +27,10 @@
mtype: "<type \'str\'>"
}
member {
+ name: "FUNCTION_LEVEL_ATTR"
+ mtype: "<type \'str\'>"
+ }
+ member {
name: "FUNCTION_NAME_ATTR"
mtype: "<type \'str\'>"
}
@@ -48,7 +56,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add_input"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
index 4a55224..b8a56e1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
@@ -54,7 +54,7 @@
}
member_method {
name: "conv1d"
- argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ argspec: "args=[\'value\', \'filters\', \'stride\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\', \'input\', \'dilations\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"
@@ -234,19 +234,31 @@
}
member_method {
name: "max_pool"
- argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
+ argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\', \'input\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "max_pool1d"
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\'], "
+ }
+ member_method {
+ name: "max_pool2d"
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
}
member_method {
name: "max_pool3d"
argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], "
}
member_method {
+ name: "max_pool_v2"
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "max_pool_with_argmax"
- argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'Targmax\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'NHWC\', \"<dtype: \'int64\'>\", \'None\', \'None\'], "
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'Targmax\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "moments"
- argspec: "args=[\'x\', \'axes\', \'shift\', \'name\', \'keep_dims\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'x\', \'axes\', \'shift\', \'name\', \'keep_dims\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "nce_loss"
@@ -362,7 +374,7 @@
}
member_method {
name: "sufficient_statistics"
- argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'None\'], "
+ argspec: "args=[\'x\', \'axes\', \'shift\', \'keep_dims\', \'name\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
name: "tanh"
@@ -382,7 +394,7 @@
}
member_method {
name: "weighted_moments"
- argspec: "args=[\'x\', \'axes\', \'frequency_weights\', \'name\', \'keep_dims\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'x\', \'axes\', \'frequency_weights\', \'name\', \'keep_dims\', \'keepdims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "with_space_to_batch"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
index 14cf5ce..f273e11 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-dropout-wrapper.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.nn.rnn_cell.DropoutWrapper"
tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
index 9de7307..4003e87 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
@@ -1,6 +1,7 @@
path: "tensorflow.nn.rnn_cell.ResidualWrapper"
tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 56d0605..64f5720 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -1461,6 +1461,10 @@
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "is_tensor"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "is_variable_initialized"
argspec: "args=[\'variable\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
index c52a0cd..28f26fe 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt
@@ -1097,6 +1097,10 @@
argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "ExperimentalRebatchDataset"
+ argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "ExperimentalScanDataset"
argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None"
}
@@ -3053,6 +3057,10 @@
argspec: "args=[\'input\', \'out_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "ShardDataset"
+ argspec: "args=[\'input_dataset\', \'num_shards\', \'index\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "ShardedFilename"
argspec: "args=[\'basename\', \'shard\', \'num_shards\'], varargs=None, keywords=None, defaults=None"
}
@@ -3717,6 +3725,10 @@
argspec: "args=[\'input_a\', \'input_b\', \'element_dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "TensorListConcatV2"
+ argspec: "args=[\'input_handle\', \'element_shape\', \'leading_dims\', \'element_dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "TensorListElementShape"
argspec: "args=[\'input_handle\', \'shape_type\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt
index f79029d..cc2d5c8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-aggregation-method.pbtxt
@@ -1,6 +1,6 @@
path: "tensorflow.AggregationMethod"
tf_class {
- is_instance: "<class \'tensorflow.python.ops.gradients_impl.AggregationMethod\'>"
+ is_instance: "<class \'tensorflow.python.ops.gradients_util.AggregationMethod\'>"
is_instance: "<type \'object\'>"
member {
name: "ADD_N"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.audio.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.audio.pbtxt
index ce29615..6c57240 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.audio.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.audio.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.audio"
tf_module {
member_method {
+ name: "decode_wav"
+ argspec: "args=[\'contents\', \'desired_channels\', \'desired_samples\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'-1\', \'None\'], "
+ }
+ member_method {
name: "encode_wav"
argspec: "args=[\'audio\', \'sample_rate\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 951b2df..195c104 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -91,6 +91,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index f157351..043584c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -93,6 +93,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index 690da98..76d15f4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -92,6 +92,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index fe0bc1a..a6c7a2d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -93,6 +93,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
index 261129b..ae177eb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -93,6 +93,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
index 0b34bbc..c15c73f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -93,6 +93,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
index 0e61890..567f48d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -93,6 +93,10 @@
argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "shuffle"
argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt
index 5b5c9e2..0d0d350 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.distribute.pbtxt
@@ -37,10 +37,6 @@
mtype: "<type \'module\'>"
}
member_method {
- name: "get_loss_reduction"
- argspec: "args=[], varargs=None, keywords=None, defaults=None"
- }
- member_method {
name: "get_replica_context"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.experimental.-module.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.experimental.-module.pbtxt
index c364b02..3c5add1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.experimental.-module.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.experimental.-module.pbtxt
@@ -13,18 +13,6 @@
mtype: "<type \'property\'>"
}
member {
- name: "owned_submodules"
- mtype: "<type \'property\'>"
- }
- member {
- name: "owned_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "owned_variables"
- mtype: "<type \'property\'>"
- }
- member {
name: "submodules"
mtype: "<type \'property\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
index 283cc6a..bb44ba0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-model.pbtxt
@@ -135,7 +135,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
index 95e405a..44fc15e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.-sequential.pbtxt
@@ -140,7 +140,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
index 1bfd51c..4f7ace4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.experimental.-peephole-l-s-t-m-cell.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt
new file mode 100644
index 0000000..b03cbb8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.constant.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.constant"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Constant\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'value\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
new file mode 100644
index 0000000..02f8c25
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_normal.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_normal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.GlorotNormal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
new file mode 100644
index 0000000..6d18a3b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.glorot_uniform.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.keras.initializers.glorot_uniform"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.GlorotUniform\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.VarianceScaling\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt
new file mode 100644
index 0000000..dcdb6dd
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.identity.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.identity"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Identity\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'gain\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt
new file mode 100644
index 0000000..cc2dd17
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.ones.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.keras.initializers.ones"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Ones\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt
new file mode 100644
index 0000000..855065c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.orthogonal.pbtxt
@@ -0,0 +1,18 @@
+path: "tensorflow.keras.initializers.orthogonal"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Orthogonal\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'gain\', \'seed\'], varargs=None, keywords=None, defaults=[\'1.0\', \'None\'], "
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
index 7412cd1..15a56fb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.pbtxt
@@ -48,6 +48,34 @@
name: "Zeros"
mtype: "<type \'type\'>"
}
+ member {
+ name: "constant"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_normal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "glorot_uniform"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "identity"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "ones"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "orthogonal"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "zeros"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "deserialize"
argspec: "args=[\'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt
new file mode 100644
index 0000000..f9b3359
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.initializers.zeros.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.keras.initializers.zeros"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Zeros\'>"
+ is_instance: "<class \'tensorflow.python.ops.init_ops_v2.Initializer\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_config"
+ argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_config"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
index 8a0b8eb..eab888c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activation.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
index abb3c23..96c7acc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-activity-regularization.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
index b27db4e..9e8aae1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-add.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
index 50998ac..01fc730 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-alpha-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
index be17aea..8b6a151 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
index 7f21b44..3c78457 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
index 2ac86f1..e6e96a0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
index f6b1dd2..ec2d5b1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-average.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
index 3da1f43..afff790 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
index a7be5ac..d7ab835 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
index c5c29be..6654f86 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-avg-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
index b13f963..05ac793 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-batch-normalization.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
index 880d18e..94f3a46 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-bidirectional.pbtxt
@@ -115,7 +115,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
index 1eb0cf1..e0eae17 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-concatenate.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
index d9394e6..ec8a44c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv-l-s-t-m2-d.pbtxt
@@ -196,7 +196,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
index a0f6dc8..350d49a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
index 037b92f..9b48eb6 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
index 6a0d027..1708d6a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
index 66b5bd7..5018492 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
index e73133f..fd24af3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-conv3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
index 7af6b2b..fbc7609 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
index baff492..671a004 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
index 63d30a6..dd6519c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
index 7a29cbb..648f480 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d-transpose.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
index 87c75c0..87a07ea 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-convolution3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
index f69104d..6f3a153 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
index aa05471..a1c418c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
index d61f1dd..ad98f9c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-cropping3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
index 3caa3ff..ca6a327 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense-features.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
index 06e8b6b..ef12b2e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dense.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
index 9fdf6f6..eacfb37 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-depthwise-conv2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
index cbe1020..7928ceb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dot.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
index 0efba09..a7fa545 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
index b34c499..483ba65 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-e-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
index 51dd853..4d0e5e1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-embedding.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
index dcd18a9..5947047 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-flatten.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
index f029907..b4efdf3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
index ac2d8c9..811990a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-g-r-u.pbtxt
@@ -180,7 +180,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
index 15cbcfe..1686768 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-dropout.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
index 865b898..69bca6a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-gaussian-noise.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
index 3e17aca..9a4119d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
index b160687..2ca1eb1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
index 70e8d51..4331adc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-average-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
index 809dc85..6e91b4a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
index 3fbce8c..85887a5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
index 70e4103..dd20fd1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-avg-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
index 000bf54..3372ae7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
index 8ffbf07..0fb1882 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
index 3803d2b..5b1c850 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
index 2866822..49e59e0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
index b83ed67..9504f64 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
index e689d69..42de6ae 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-global-max-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
index bb6edda..f388b84 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-input-layer.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
index 5fb3f9d..d2634dd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
index 89dfc2a..622a8e2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-l-s-t-m.pbtxt
@@ -180,7 +180,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
index 376bec0..da2373c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-lambda.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
index c5f91a6..2e47132 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-layer.pbtxt
@@ -105,7 +105,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
index bde8887..a74e935 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-leaky-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-linear-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-linear-model.pbtxt
index c4726cf..bf6c84b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-linear-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-linear-model.pbtxt
@@ -140,7 +140,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
index 16945f2..0f4c071 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
index f05741f..5eea071 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-locally-connected2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
index 7885db4..a16ceef 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-masking.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
index 9380d26..e61d730 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
index 8eb8218..a21c403 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
index 0c96f86..fb8613a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pool3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
index 0c6b230..a433d49 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
index eb7ca52..fa6ad6f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
index e724e90..05e2ace 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-max-pooling3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
index dafbd09..ce62223 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-maximum.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
index 3122fbe..a0ff4f9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-minimum.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
index 0527cda..558cc0d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-multiply.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
index 814e5a5..5863fbb 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-p-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
index aa1731a..4d7413b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-permute.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
index 9d7dd85..67ab60b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -110,7 +110,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
index e9bba29..eb32ba2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
index 3c783eb..81ac253 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-repeat-vector.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
index b8e0882..dd4dc49 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-reshape.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
index 310f369..c8724f0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv1-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
index df19d78..8c47395 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-conv2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
index bf90950..c0b6ad4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution1-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
index 5d66bc6..c5566c1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-separable-convolution2-d.pbtxt
@@ -108,7 +108,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
index 88e9300..f91aac8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n-cell.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
index 9d81c6d..eb2a7b9 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-simple-r-n-n.pbtxt
@@ -167,7 +167,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
index 712eb0c..f0411e2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-softmax.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
index dfc4ca2..2a2fd2e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout1-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
index 5e4f727..e4d1d43 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout2-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
index 9d893cb..4e641a8e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-spatial-dropout3-d.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
index a2ed954..591796e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-stacked-r-n-n-cells.pbtxt
@@ -114,7 +114,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
index 8a0818e..67555db 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-subtract.pbtxt
@@ -107,7 +107,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
index b5591b4..0ed7da5 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-thresholded-re-l-u.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
index 210e4fd..9492b0b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-time-distributed.pbtxt
@@ -111,7 +111,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
index da2213a..16c31d3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
index e2c303d..cf1a076 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
index 396e774..5cded98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-up-sampling3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
index 8b6418d..16f3f06 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-wrapper.pbtxt
@@ -110,7 +110,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
index e8fda4c..59997a8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding1-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
index 50c52d2..9a327c2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding2-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
index 84c6b78..7933868 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-zero-padding3-d.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
index 476c597..6c8faef 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.experimental.-layer-normalization.pbtxt
@@ -106,7 +106,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
index eb1ab1d..3132e8d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-model.pbtxt
@@ -135,7 +135,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
index c69cf28..b5ef70e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.models.-sequential.pbtxt
@@ -140,7 +140,7 @@
}
member_method {
name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
+ argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\'], "
}
member_method {
name: "apply"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-adjoint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-adjoint.pbtxt
new file mode 100644
index 0000000..37344f7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-adjoint.pbtxt
@@ -0,0 +1,150 @@
+path: "tensorflow.linalg.LinearOperatorAdjoint"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_adjoint.LinearOperatorAdjoint\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "operator"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'operator\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "cholesky"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'cholesky\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "inverse"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'inverse\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt
index c7a5096..ddef774 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-block-diag.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt
index 3900c75..97a6b1a 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
index 7b87609..e2bfe7e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
index 5bddba8..8885526 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -73,6 +77,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_hermitian_spectrum"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt
index 62ba8bb..2a017fc 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-composition.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt
index 0803fee..31dcf7b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-diag.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
index 6def328..0ad39b4 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-full-matrix.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt
index dbf1ac8..f66a5a8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-identity.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -61,6 +65,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-inversion.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-inversion.pbtxt
index 6a3fe4d..a7eb144 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-inversion.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-inversion.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt
index 85d902b..c983f8c 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-kronecker.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -64,6 +68,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
index 638d82a..813aec2 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-low-rank-update.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "base_operator"
mtype: "<type \'property\'>"
}
@@ -84,6 +88,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
index ab1b04b..0bb7a15 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-lower-triangular.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
index 961969a..7747c98 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-scaled-identity.pbtxt
@@ -5,6 +5,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -65,6 +69,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt
index e76738a..590782b 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator-zeros.pbtxt
@@ -4,6 +4,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -60,6 +64,10 @@
argspec: "args=[\'self\', \'mat\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt
index b35cd69..ed6bfdf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.-linear-operator.pbtxt
@@ -3,6 +3,10 @@
is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
is_instance: "<type \'object\'>"
member {
+ name: "H"
+ mtype: "<type \'property\'>"
+ }
+ member {
name: "batch_shape"
mtype: "<type \'property\'>"
}
@@ -59,6 +63,10 @@
argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
}
member_method {
+ name: "adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'adjoint\'], "
+ }
+ member_method {
name: "assert_non_singular"
argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index f9119cd..08b928e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -5,6 +5,10 @@
mtype: "<type \'type\'>"
}
member {
+ name: "LinearOperatorAdjoint"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "LinearOperatorBlockDiag"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
index 1fe179f..68cb07e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.-op-hint-argument-tracker.pbtxt
@@ -4,7 +4,7 @@
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'function_name\', \'unique_function_id\', \'node_name_prefix\', \'attr_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt
index 66e692a..3ac478f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.lite.-op-hint.pbtxt
@@ -15,6 +15,10 @@
mtype: "<type \'str\'>"
}
member {
+ name: "CHILDREN_INPUTS_MAPPINGS"
+ mtype: "<type \'str\'>"
+ }
+ member {
name: "FUNCTION_AGGREGATE_ATTR"
mtype: "<type \'str\'>"
}
@@ -23,6 +27,10 @@
mtype: "<type \'str\'>"
}
member {
+ name: "FUNCTION_LEVEL_ATTR"
+ mtype: "<type \'str\'>"
+ }
+ member {
name: "FUNCTION_NAME_ATTR"
mtype: "<type \'str\'>"
}
@@ -48,7 +56,7 @@
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'function_name\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'function_name\', \'level\', \'children_inputs_mappings\'], varargs=None, keywords=kwargs, defaults=[\'1\', \'None\'], "
}
member_method {
name: "add_input"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
index c75c75f..930cce3 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
@@ -50,7 +50,7 @@
}
member_method {
name: "conv1d"
- argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'input\', \'filters\', \'stride\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\', \'None\'], "
}
member_method {
name: "conv2d"
@@ -194,7 +194,15 @@
}
member_method {
name: "max_pool"
- argspec: "args=[\'value\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "max_pool1d"
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NWC\', \'None\'], "
+ }
+ member_method {
+ name: "max_pool2d"
+ argspec: "args=[\'input\', \'ksize\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
}
member_method {
name: "max_pool3d"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
index e2496df..9537224 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
@@ -12,8 +12,4 @@
name: "RNNCell"
mtype: "<type \'type\'>"
}
- member {
- name: "ResidualWrapper"
- mtype: "<type \'type\'>"
- }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index bb7ea2e..f020682 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -669,6 +669,10 @@
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "is_tensor"
+ argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
index c52a0cd..28f26fe 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt
@@ -1097,6 +1097,10 @@
argspec: "args=[\'seed\', \'seed2\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "ExperimentalRebatchDataset"
+ argspec: "args=[\'input_dataset\', \'num_workers\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "ExperimentalScanDataset"
argspec: "args=[\'input_dataset\', \'initial_state\', \'other_arguments\', \'f\', \'output_types\', \'output_shapes\', \'preserve_cardinality\'], varargs=None, keywords=None, defaults=None"
}
@@ -3053,6 +3057,10 @@
argspec: "args=[\'input\', \'out_type\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "ShardDataset"
+ argspec: "args=[\'input_dataset\', \'num_shards\', \'index\', \'output_types\', \'output_shapes\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "ShardedFilename"
argspec: "args=[\'basename\', \'shard\', \'num_shards\'], varargs=None, keywords=None, defaults=None"
}
@@ -3717,6 +3725,10 @@
argspec: "args=[\'input_a\', \'input_b\', \'element_dtype\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "TensorListConcatV2"
+ argspec: "args=[\'input_handle\', \'element_shape\', \'leading_dims\', \'element_dtype\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "TensorListElementShape"
argspec: "args=[\'input_handle\', \'shape_type\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.rnn.-dropout-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.rnn.-dropout-wrapper.pbtxt
index 7781337..9f6ce04 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.rnn.-dropout-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.rnn.-dropout-wrapper.pbtxt
@@ -1,8 +1,10 @@
path: "tensorflow.rnn.DropoutWrapper"
tf_class {
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapperV2\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.DropoutWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.rnn.-residual-wrapper.pbtxt
similarity index 91%
rename from tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
rename to tensorflow/tools/api/golden/v2/tensorflow.rnn.-residual-wrapper.pbtxt
index 9de7307..51dc8c1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-residual-wrapper.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.rnn.-residual-wrapper.pbtxt
@@ -1,6 +1,10 @@
-path: "tensorflow.nn.rnn_cell.ResidualWrapper"
+path: "tensorflow.rnn.ResidualWrapper"
tf_class {
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapperV2\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV2\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.ResidualWrapper\'>"
+ is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl._RNNCellWrapperV1\'>"
is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
@@ -132,11 +136,11 @@
}
member_method {
name: "build"
- argspec: "args=[\'self\', \'_\'], varargs=None, keywords=None, defaults=None"
+ argspec: "args=[\'self\', \'inputs_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "call"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=kwargs, defaults=None"
+ argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=kwargs, defaults=None"
}
member_method {
name: "compute_mask"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.rnn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.rnn.pbtxt
index 42b1353..32be6c7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.rnn.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.rnn.pbtxt
@@ -4,4 +4,8 @@
name: "DropoutWrapper"
mtype: "<type \'type\'>"
}
+ member {
+ name: "ResidualWrapper"
+ mtype: "<type \'type\'>"
+ }
}
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 2c348a0..afb2827 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -540,9 +540,12 @@
python file_name_test.py
}
+do_libtensorflow_framework_not_depend_on_cuda_check() {
+ bazel build --action_env=TF_NEED_CUDA=1 --define framework_shared_object=true --config=cuda --nobuild_tests_only tensorflow/core/platform/default/build_config:libtensorflow_cuda_check_deps
+}
# Supply all sanity step commands and descriptions
-SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_check_file_name_test")
-SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Check file names for cases")
+SANITY_STEPS=("do_pylint PYTHON2" "do_pylint PYTHON3" "do_check_futures_test" "do_buildifier" "do_bazel_nobuild" "do_pip_package_licenses_check" "do_lib_package_licenses_check" "do_java_package_licenses_check" "do_pip_smoke_test" "do_check_load_py_test" "do_code_link_check" "do_check_file_name_test" "do_libtensorflow_framework_not_depend_on_cuda_check")
+SANITY_STEPS_DESC=("Python 2 pylint" "Python 3 pylint" "Check that python files have certain __future__ imports" "buildifier check" "bazel nobuild" "pip: license check for external dependencies" "C library: license check for external dependencies" "Java Native Library: license check for external dependencies" "Pip Smoke Test: Checking py_test dependencies exist in pip package" "Check load py_test: Check that BUILD files with py_test target properly load py_test" "Code Link Check: Check there are no broken links" "Check file names for cases" "Check gpu libtensorflow_framework.so does not depend on cuda shared libraries.")
INCREMENTAL_FLAG=""
DEFAULT_BAZEL_CONFIGS=""
diff --git a/tensorflow/tools/ci_build/install/install_deb_packages.sh b/tensorflow/tools/ci_build/install/install_deb_packages.sh
index 989f2a9..bd81001 100755
--- a/tensorflow/tools/ci_build/install/install_deb_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_deb_packages.sh
@@ -68,12 +68,6 @@
zip \
zlib1g-dev
-apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-4.0.1-ga-cuda9.0 && \
- apt-get update && \
- apt-get install libnvinfer4=4.1.2-1+cuda9.0 && \
- apt-get install libnvinfer-dev=4.1.2-1+cuda9.0
-
# populate the database
updatedb
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 3878452..0fa8c6c 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -40,8 +40,8 @@
pip3 install virtualenv
# Install six.
-pip2 install --upgrade six==1.10.0
-pip3 install --upgrade six==1.10.0
+pip2 install --upgrade six==1.11.0
+pip3 install --upgrade six==1.11.0
# Install absl-py.
pip2 install --upgrade absl-py
@@ -97,9 +97,9 @@
pip2 install pylint==1.6.4
pip3 install pylint==1.6.4
-# pep8 tests require the following:
-pip2 install pep8
-pip3 install pep8
+# pycodestyle tests require the following:
+pip2 install pycodestyle
+pip3 install pycodestyle
# tf.mock require the following for python2:
pip2 install mock
diff --git a/tensorflow/tools/compatibility/BUILD b/tensorflow/tools/compatibility/BUILD
index 31dbc02..74ef9ec 100644
--- a/tensorflow/tools/compatibility/BUILD
+++ b/tensorflow/tools/compatibility/BUILD
@@ -34,6 +34,13 @@
name = "tf_upgrade",
srcs = ["tf_upgrade.py"],
srcs_version = "PY2AND3",
+ deps = [":tf_upgrade_lib"],
+)
+
+py_library(
+ name = "tf_upgrade_lib",
+ srcs = ["tf_upgrade.py"],
+ srcs_version = "PY2AND3",
deps = [":ast_edits"],
)
@@ -42,7 +49,7 @@
srcs = ["tf_upgrade_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":tf_upgrade",
+ ":tf_upgrade_lib",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
"@six_archive//:six",
@@ -89,7 +96,7 @@
srcs = ["tf_upgrade_v2_test.py"],
srcs_version = "PY2AND3",
deps = [
- ":tf_upgrade_v2",
+ ":tf_upgrade_v2_lib",
"//tensorflow:tensorflow_py",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/tools/compatibility/ast_edits.py b/tensorflow/tools/compatibility/ast_edits.py
index 7c11cc1..ed388b8 100644
--- a/tensorflow/tools/compatibility/ast_edits.py
+++ b/tensorflow/tools/compatibility/ast_edits.py
@@ -116,8 +116,8 @@
* `function_warnings`: maps full names of functions to warnings that will be
printed out if the function is used. (e.g. tf.nn.convolution())
* `function_transformers`: maps function names to custom handlers
- * `leftover_warnings`: These warnings are printed if a matching Attribute
- still exists after all other transformations have run.
+ * `module_deprecations`: maps module names to warnings that will be printed
+ if the module is still used after all other transformations have run
For an example, see `TFAPIChangeSpec`.
"""
diff --git a/tensorflow/tools/compatibility/renames_v2.py b/tensorflow/tools/compatibility/renames_v2.py
index 011be51..c29b259 100644
--- a/tensorflow/tools/compatibility/renames_v2.py
+++ b/tensorflow/tools/compatibility/renames_v2.py
@@ -150,6 +150,7 @@
'tf.disable_v2_batch_normalization': 'tf.compat.v1.disable_v2_batch_normalization',
'tf.disable_v2_behavior': 'tf.compat.v1.disable_v2_behavior',
'tf.disable_v2_tensorshape': 'tf.compat.v1.disable_v2_tensorshape',
+ 'tf.distribute.get_loss_reduction': 'tf.compat.v1.distribute.get_loss_reduction',
'tf.distributions.Bernoulli': 'tf.compat.v1.distributions.Bernoulli',
'tf.distributions.Beta': 'tf.compat.v1.distributions.Beta',
'tf.distributions.Categorical': 'tf.compat.v1.distributions.Categorical',
@@ -268,7 +269,6 @@
'tf.keras.initializers.Orthogonal': 'tf.compat.v1.keras.initializers.Orthogonal',
'tf.keras.initializers.TruncatedNormal': 'tf.compat.v1.keras.initializers.TruncatedNormal',
'tf.keras.initializers.VarianceScaling': 'tf.compat.v1.keras.initializers.VarianceScaling',
- 'tf.keras.initializers.constant': 'tf.compat.v1.keras.initializers.constant',
'tf.keras.initializers.glorot_normal': 'tf.compat.v1.keras.initializers.glorot_normal',
'tf.keras.initializers.glorot_uniform': 'tf.compat.v1.keras.initializers.glorot_uniform',
'tf.keras.initializers.he_normal': 'tf.compat.v1.keras.initializers.he_normal',
@@ -277,13 +277,11 @@
'tf.keras.initializers.lecun_normal': 'tf.compat.v1.keras.initializers.lecun_normal',
'tf.keras.initializers.lecun_uniform': 'tf.compat.v1.keras.initializers.lecun_uniform',
'tf.keras.initializers.normal': 'tf.compat.v1.keras.initializers.normal',
- 'tf.keras.initializers.ones': 'tf.compat.v1.keras.initializers.ones',
'tf.keras.initializers.orthogonal': 'tf.compat.v1.keras.initializers.orthogonal',
'tf.keras.initializers.random_normal': 'tf.compat.v1.keras.initializers.random_normal',
'tf.keras.initializers.random_uniform': 'tf.compat.v1.keras.initializers.random_uniform',
'tf.keras.initializers.truncated_normal': 'tf.compat.v1.keras.initializers.truncated_normal',
'tf.keras.initializers.uniform': 'tf.compat.v1.keras.initializers.uniform',
- 'tf.keras.initializers.zeros': 'tf.compat.v1.keras.initializers.zeros',
'tf.layers.AveragePooling1D': 'tf.compat.v1.layers.AveragePooling1D',
'tf.layers.AveragePooling2D': 'tf.compat.v1.layers.AveragePooling2D',
'tf.layers.AveragePooling3D': 'tf.compat.v1.layers.AveragePooling3D',
@@ -432,6 +430,7 @@
'tf.nn.depthwise_conv2d_native_backprop_input': 'tf.nn.depthwise_conv2d_backprop_input',
'tf.nn.dynamic_rnn': 'tf.compat.v1.nn.dynamic_rnn',
'tf.nn.log_uniform_candidate_sampler': 'tf.random.log_uniform_candidate_sampler',
+ 'tf.nn.max_pool_v2': 'tf.nn.max_pool',
'tf.nn.quantized_avg_pool': 'tf.compat.v1.nn.quantized_avg_pool',
'tf.nn.quantized_conv2d': 'tf.compat.v1.nn.quantized_conv2d',
'tf.nn.quantized_max_pool': 'tf.compat.v1.nn.quantized_max_pool',
@@ -444,6 +443,7 @@
'tf.nn.rnn_cell.GRUCell': 'tf.compat.v1.nn.rnn_cell.GRUCell',
'tf.nn.rnn_cell.LSTMCell': 'tf.compat.v1.nn.rnn_cell.LSTMCell',
'tf.nn.rnn_cell.MultiRNNCell': 'tf.compat.v1.nn.rnn_cell.MultiRNNCell',
+ 'tf.nn.rnn_cell.ResidualWrapper': 'tf.compat.v1.nn.rnn_cell.ResidualWrapper',
'tf.nn.static_bidirectional_rnn': 'tf.compat.v1.nn.static_bidirectional_rnn',
'tf.nn.static_rnn': 'tf.compat.v1.nn.static_rnn',
'tf.nn.uniform_candidate_sampler': 'tf.random.uniform_candidate_sampler',
diff --git a/tensorflow/tools/compatibility/reorders_v2.py b/tensorflow/tools/compatibility/reorders_v2.py
index 01556b1..e7edf3f 100644
--- a/tensorflow/tools/compatibility/reorders_v2.py
+++ b/tensorflow/tools/compatibility/reorders_v2.py
@@ -28,15 +28,16 @@
reorders = {
'tf.argmax': ['input', 'axis', 'name', 'dimension', 'output_type'],
'tf.argmin': ['input', 'axis', 'name', 'dimension', 'output_type'],
- 'tf.batch_gather': ['params', 'indices', 'name'],
- 'tf.batch_to_space': ['input', 'crops', 'block_size', 'name'],
+ 'tf.batch_to_space': ['input', 'crops', 'block_size', 'name', 'block_shape'],
'tf.boolean_mask': ['tensor', 'mask', 'name', 'axis'],
'tf.cond': ['pred', 'true_fn', 'false_fn', 'strict', 'name', 'fn1', 'fn2'],
'tf.confusion_matrix': ['labels', 'predictions', 'num_classes', 'dtype', 'name', 'weights'],
- 'tf.convert_to_tensor': ['value', 'dtype', 'name', 'preferred_dtype'],
+ 'tf.convert_to_tensor': ['value', 'dtype', 'name', 'preferred_dtype', 'dtype_hint'],
'tf.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'],
'tf.depth_to_space': ['input', 'block_size', 'name', 'data_format'],
'tf.feature_column.categorical_column_with_vocabulary_file': ['key', 'vocabulary_file', 'vocabulary_size', 'num_oov_buckets', 'default_value', 'dtype'],
+ 'tf.gradients': ['ys', 'xs', 'grad_ys', 'name', 'colocate_gradients_with_ops', 'gate_gradients', 'aggregation_method', 'stop_gradients', 'unconnected_gradients'],
+ 'tf.hessians': ['ys', 'xs', 'name', 'colocate_gradients_with_ops', 'gate_gradients', 'aggregation_method'],
'tf.image.sample_distorted_bounding_box': ['image_size', 'bounding_boxes', 'seed', 'seed2', 'min_object_covered', 'aspect_ratio_range', 'area_range', 'max_attempts', 'use_image_if_no_bounding_boxes', 'name'],
'tf.io.decode_csv': ['records', 'record_defaults', 'field_delim', 'use_quote_delim', 'name', 'na_value', 'select_cols'],
'tf.io.parse_example': ['serialized', 'features', 'name', 'example_names'],
@@ -57,27 +58,28 @@
'tf.math.reduce_prod': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.math.reduce_sum': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.multinomial': ['logits', 'num_samples', 'seed', 'name', 'output_dtype'],
- 'tf.nn.conv1d': ['value', 'filters', 'stride', 'padding', 'use_cudnn_on_gpu', 'data_format', 'name'],
- 'tf.nn.conv2d': ['input', 'filter', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
+ 'tf.nn.conv1d': ['value', 'filters', 'stride', 'padding', 'use_cudnn_on_gpu', 'data_format', 'name', 'input'],
+ 'tf.nn.conv2d': ['input', 'filter', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name', 'filters'],
'tf.nn.conv2d_backprop_filter': ['input', 'filter_sizes', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
- 'tf.nn.conv2d_backprop_input': ['input_sizes', 'filter', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name'],
- 'tf.nn.convolution': ['input', 'filter', 'padding', 'strides', 'dilation_rate', 'name', 'data_format'],
+ 'tf.nn.conv2d_backprop_input': ['input_sizes', 'filter', 'out_backprop', 'strides', 'padding', 'use_cudnn_on_gpu', 'data_format', 'dilations', 'name', 'filters'],
+ 'tf.nn.convolution': ['input', 'filter', 'padding', 'strides', 'dilation_rate', 'name', 'data_format', 'filters', 'dilations'],
'tf.nn.crelu': ['features', 'name', 'axis'],
'tf.nn.ctc_beam_search_decoder': ['inputs', 'sequence_length', 'beam_width', 'top_paths', 'merge_repeated'],
'tf.nn.depth_to_space': ['input', 'block_size', 'name', 'data_format'],
- 'tf.nn.depthwise_conv2d': ['input', 'filter', 'strides', 'padding', 'rate', 'name', 'data_format'],
+ 'tf.nn.depthwise_conv2d': ['input', 'filter', 'strides', 'padding', 'rate', 'name', 'data_format', 'dilations'],
'tf.nn.embedding_lookup': ['params', 'ids', 'partition_strategy', 'name', 'validate_indices', 'max_norm'],
'tf.nn.embedding_lookup_sparse': ['params', 'sp_ids', 'sp_weights', 'partition_strategy', 'name', 'combiner', 'max_norm'],
'tf.nn.fractional_avg_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'],
'tf.nn.fractional_max_pool': ['value', 'pooling_ratio', 'pseudo_random', 'overlapping', 'deterministic', 'seed', 'seed2', 'name'],
'tf.nn.in_top_k': ['predictions', 'targets', 'k', 'name'],
- 'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims'],
- 'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format'],
- 'tf.nn.separable_conv2d': ['input', 'depthwise_filter', 'pointwise_filter', 'strides', 'padding', 'rate', 'name', 'data_format'],
- 'tf.nn.softmax_cross_entropy_with_logits': ['_sentinel', 'labels', 'logits', 'dim', 'name'],
- 'tf.nn.space_to_batch': ['input', 'paddings', 'block_size', 'name'],
+ 'tf.nn.max_pool': ['value', 'ksize', 'strides', 'padding', 'data_format', 'name', 'input'],
+ 'tf.nn.moments': ['x', 'axes', 'shift', 'name', 'keep_dims', 'keepdims'],
+ 'tf.nn.pool': ['input', 'window_shape', 'pooling_type', 'padding', 'dilation_rate', 'strides', 'name', 'data_format', 'dilations'],
+ 'tf.nn.separable_conv2d': ['input', 'depthwise_filter', 'pointwise_filter', 'strides', 'padding', 'rate', 'name', 'data_format', 'dilations'],
+ 'tf.nn.softmax_cross_entropy_with_logits': ['_sentinel', 'labels', 'logits', 'dim', 'name', 'axis'],
+ 'tf.nn.space_to_batch': ['input', 'paddings', 'block_size', 'name', 'block_shape'],
'tf.nn.space_to_depth': ['input', 'block_size', 'name', 'data_format'],
- 'tf.nn.weighted_moments': ['x', 'axes', 'frequency_weights', 'name', 'keep_dims'],
+ 'tf.nn.weighted_moments': ['x', 'axes', 'frequency_weights', 'name', 'keep_dims', 'keepdims'],
'tf.norm': ['tensor', 'ord', 'axis', 'keepdims', 'name', 'keep_dims'],
'tf.pad': ['tensor', 'paddings', 'mode', 'name', 'constant_values'],
'tf.parse_example': ['serialized', 'features', 'name', 'example_names'],
@@ -88,7 +90,7 @@
'tf.random_poisson': ['lam', 'shape', 'dtype', 'seed', 'name'],
'tf.reduce_all': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.reduce_any': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
- 'tf.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices'],
+ 'tf.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices', 'keepdims'],
'tf.reduce_logsumexp': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.reduce_max': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
'tf.reduce_mean': ['input_tensor', 'axis', 'keepdims', 'name', 'reduction_indices', 'keep_dims'],
@@ -100,17 +102,17 @@
'tf.serialize_sparse': ['sp_input', 'name', 'out_type'],
'tf.shape': ['input', 'name', 'out_type'],
'tf.size': ['input', 'name', 'out_type'],
- 'tf.space_to_batch': ['input', 'paddings', 'block_size', 'name'],
+ 'tf.space_to_batch': ['input', 'paddings', 'block_size', 'name', 'block_shape'],
'tf.space_to_depth': ['input', 'block_size', 'name', 'data_format'],
'tf.sparse.add': ['a', 'b', 'threshold', 'thresh'],
- 'tf.sparse.concat': ['axis', 'sp_inputs', 'name', 'expand_nonconcat_dim', 'concat_dim'],
+ 'tf.sparse.concat': ['axis', 'sp_inputs', 'name', 'expand_nonconcat_dim', 'concat_dim', 'expand_nonconcat_dims'],
'tf.sparse.reduce_max': ['sp_input', 'axis', 'keepdims', 'reduction_axes', 'keep_dims'],
'tf.sparse.segment_mean': ['data', 'indices', 'segment_ids', 'name', 'num_segments'],
'tf.sparse.segment_sqrt_n': ['data', 'indices', 'segment_ids', 'name', 'num_segments'],
'tf.sparse.segment_sum': ['data', 'indices', 'segment_ids', 'name', 'num_segments'],
'tf.sparse.split': ['keyword_required', 'sp_input', 'num_split', 'axis', 'name', 'split_dim'],
'tf.sparse_add': ['a', 'b', 'threshold', 'thresh'],
- 'tf.sparse_concat': ['axis', 'sp_inputs', 'name', 'expand_nonconcat_dim', 'concat_dim'],
+ 'tf.sparse_concat': ['axis', 'sp_inputs', 'name', 'expand_nonconcat_dim', 'concat_dim', 'expand_nonconcat_dims'],
'tf.sparse_matmul': ['a', 'b', 'transpose_a', 'transpose_b', 'a_is_sparse', 'b_is_sparse', 'name'],
'tf.sparse_reduce_max': ['sp_input', 'axis', 'keepdims', 'reduction_axes', 'keep_dims'],
'tf.sparse_segment_mean': ['data', 'indices', 'segment_ids', 'name', 'num_segments'],
@@ -118,7 +120,7 @@
'tf.sparse_segment_sum': ['data', 'indices', 'segment_ids', 'name', 'num_segments'],
'tf.sparse_split': ['keyword_required', 'sp_input', 'num_split', 'axis', 'name', 'split_dim'],
'tf.strings.length': ['input', 'name', 'unit'],
- 'tf.strings.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices'],
+ 'tf.strings.reduce_join': ['inputs', 'axis', 'keep_dims', 'separator', 'name', 'reduction_indices', 'keepdims'],
'tf.strings.substr': ['input', 'pos', 'len', 'name', 'unit'],
'tf.substr': ['input', 'pos', 'len', 'name', 'unit'],
'tf.test.assert_equal_graph_def': ['actual', 'expected', 'checkpoint_v2'],
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2.py b/tensorflow/tools/compatibility/tf_upgrade_v2.py
index 74d2973..18271ef 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2.py
@@ -223,6 +223,9 @@
"tf.nn.max_pool_with_argmax": {
"Targmax": "output_dtype",
},
+ "tf.nn.max_pool": {
+ "value": "input"
+ },
"tf.multinomial": {
"output_dtype": "dtype",
},
@@ -537,6 +540,8 @@
"tf.data.experimental.unbatch",
"tf.contrib.data.unique":
"tf.data.experimental.unique",
+ "tf.contrib.framework.is_tensor":
+ "tf.is_tensor",
"tf.contrib.framework.nest.assert_same_structure":
"tf.nest.assert_same_structure",
"tf.contrib.framework.nest.flatten":
@@ -697,6 +702,14 @@
"tf.compat.v1.assert_rank",
"tf.contrib.framework.argsort":
"tf.argsort",
+ "tf.nn.max_pool":
+ "tf.nn.max_pool2d",
+ 'tf.keras.initializers.zeros':
+ 'tf.compat.v1.keras.initializers.zeros',
+ 'tf.keras.initializers.ones':
+ 'tf.compat.v1.keras.initializers.ones',
+ 'tf.keras.initializers.constant':
+ 'tf.compat.v1.keras.initializers.constant',
}
# pylint: enable=line-too-long
@@ -723,7 +736,6 @@
"tf.io.serialize_many_sparse",
"tf.argmax",
"tf.argmin",
- "tf.batch_gather",
"tf.batch_to_space",
"tf.cond",
"tf.nn.space_to_batch",
@@ -808,6 +820,9 @@
"tf.nn.fractional_avg_pool",
"tf.nn.fractional_max_pool",
"tf.image.sample_distorted_bounding_box",
+ "tf.gradients",
+ "tf.hessians",
+ "tf.nn.max_pool",
}
# Functions that were reordered should be changed to the new keyword args
@@ -823,14 +838,15 @@
"the required code."
)
+ flags_warning = (
+ ast_edits.ERROR,
+ "tf.flags has been removed, please use the argparse or absl"
+ " modules if you need command line parsing.")
+
decay_function_comment = (
ast_edits.INFO,
- "<function name> has been changed to return a callable instead "
- "of a tensor when graph building, but its functionality remains "
- "unchanged during eager execution (returns a callable like "
- "before). The converter cannot detect and fix this reliably, so "
- "this usage has been converted to compat.v1 (even though it may already"
- " be correct).\n"
+ "To use learning rate decay schedules with TensorFlow 2.0, switch to "
+ "the schedules in `tf.keras.optimizers.schedules`.\n"
)
assert_return_type_comment = (
@@ -976,10 +992,6 @@
assert_rank_comment,
"tf.debugging.assert_rank_in":
assert_rank_comment,
- "tf.flags": (
- ast_edits.ERROR,
- "tf.flags has been removed, please use the argparse or absl"
- " modules if you need command line parsing."),
"tf.train.exponential_decay":
decay_function_comment,
"tf.train.piecewise_constant_decay":
@@ -1266,7 +1278,6 @@
"*.make_initializable_iterator": _iterator_transformer,
"*.make_one_shot_iterator": _iterator_transformer,
"tf.nn.dropout": _dropout_transformer,
- "tf.batch_gather": _batch_gather_transformer,
"tf.to_bfloat16": _cast_transformer,
"tf.to_complex128": _cast_transformer,
"tf.to_complex64": _cast_transformer,
@@ -1325,6 +1336,7 @@
self.module_deprecations = {
"tf.contrib": contrib_warning,
+ "tf.flags": flags_warning,
}
@@ -1576,24 +1588,6 @@
return node
-def _batch_gather_transformer(parent, node, full_name, name, logs):
- """Add batch_dims argument for gather calls."""
- # Check if the call already has a batch_dims argument
- if any([kw.arg == "batch_dims" for kw in node.keywords]):
- logs.append((ast_edits.INFO, node.lineno, node.col_offset,
- "tf.batch_gather already has batch_dims argument. Neat."))
- return None
-
- minus_one = ast.Num(n=-1)
- minus_one.lineno = 0
- minus_one.col_offset = 0
- new_arg = ast.keyword("batch_dims", minus_one)
- node.keywords.append(new_arg)
- logs.append((ast_edits.INFO, node.lineno, node.col_offset,
- "Added keyword argument batch_dims=-1 to tf.batch_gather."))
- return node
-
-
def _image_resize_transformer(parent, node, full_name, name, logs):
"""Transforms image.resize_* to image.resize(..., method=*, ...)."""
resize_method = name[7:].upper()
diff --git a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
index 440c163..ede0ccf 100644
--- a/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
+++ b/tensorflow/tools/compatibility/tf_upgrade_v2_test.py
@@ -409,7 +409,8 @@
text = "%s(a, b)\n" % decay
_, report, unused_errors, _ = self._upgrade(text)
- self.assertIn("%s has been changed to return a callable" % decay, report)
+ self.assertIn("switch to the schedules in "
+ "`tf.keras.optimizers.schedules`", report)
def testMetrics(self):
metrics = [
@@ -653,16 +654,22 @@
self.assertEqual(errors, [])
def testColocateGradientsWithOps(self):
- text = "tf.gradients(a, foo=False)\n"
+ text = "tf.gradients(yx=a, foo=False)\n"
_, unused_report, errors, new_text = self._upgrade(text)
self.assertEqual(text, new_text)
self.assertEqual(errors, [])
- text = "tf.gradients(a, colocate_gradients_with_ops=False)\n"
+ text = "tf.gradients(yx=a, colocate_gradients_with_ops=False)\n"
_, report, unused_errors, new_text = self._upgrade(text)
- self.assertEqual("tf.gradients(a)\n", new_text)
+ self.assertEqual("tf.gradients(yx=a)\n", new_text)
self.assertIn("tf.gradients no longer takes", report)
+ text = "tf.gradients(y, x, grad_ys, name, colocate, gate)\n"
+ expected = ("tf.gradients(ys=y, xs=x, grad_ys=grad_ys, name=name, "
+ "gate_gradients=gate)\n")
+ _, unused_report, errors, new_text = self._upgrade(text)
+ self.assertEqual(expected, new_text)
+
def testColocateGradientsWithOpsMinimize(self):
text = "optimizer.minimize(a, foo=False)\n"
_, unused_report, errors, new_text = self._upgrade(text)
@@ -849,6 +856,46 @@
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
+ def testConv2D(self):
+ text = (
+ "tf.nn.conv2d(input, filter, strides, padding, use_cudnn_on_gpu, "
+ "data_format)")
+ expected_text = (
+ "tf.nn.conv2d(input=input, filters=filter, strides=strides, "
+ "padding=padding, data_format=data_format)")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ text = (
+ "tf.nn.conv2d(input, filter=filter, strides=strides, padding=padding, "
+ "use_cudnn_on_gpu=use_cudnn_on_gpu)")
+ expected_text = ("tf.nn.conv2d(input=input, filters=filter, "
+ "strides=strides, padding=padding)")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ def testConv2DBackpropFilter(self):
+ text = (
+ "tf.nn.conv2d_backprop_filter(input, filter_sizes, out_backprop, "
+ "strides, padding, use_cudnn_on_gpu, data_format)")
+ expected_text = (
+ "tf.nn.conv2d_backprop_filter(input=input, filter_sizes=filter_sizes, "
+ "out_backprop=out_backprop, strides=strides, padding=padding, "
+ "data_format=data_format)")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
+ def testConv2DBackpropInput(self):
+ text = (
+ "tf.nn.conv2d_backprop_input(input_sizes, filter, out_backprop, "
+ "strides, padding, use_cudnn_on_gpu, data_format)")
+ expected_text = (
+ "tf.nn.conv2d_backprop_input(input_sizes=input_sizes, filters=filter, "
+ "out_backprop=out_backprop, strides=strides, padding=padding, "
+ "data_format=data_format)")
+ _, unused_report, unused_errors, new_text = self._upgrade(text)
+ self.assertEqual(new_text, expected_text)
+
def testSpacetoBatch(self):
text = "tf.space_to_batch_nd(input, shape, paddings, name)"
expected_text = "tf.space_to_batch(input, shape, paddings, name)"
@@ -947,19 +994,6 @@
_, unused_report, unused_errors, new_text = self._upgrade(text)
self.assertEqual(new_text, expected_text)
- def testBatchGather(self):
- text = "tf.batch_gather(foo, bar)"
- expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
- expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
- _, unused_report, unused_errors, new_text = self._upgrade(text)
- self.assertIn(new_text, [expected_text1, expected_text2])
-
- text = "tf.batch_gather(params=foo, indices=bar)"
- expected_text1 = "tf.gather(params=foo, indices=bar, batch_dims=-1)"
- expected_text2 = "tf.gather(batch_dims=-1, params=foo, indices=bar)"
- _, unused_report, unused_errors, new_text = self._upgrade(text)
- self.assertIn(new_text, [expected_text1, expected_text2])
-
def testIterators(self):
for (text, expected) in [
("(expr + yielding(data)).make_one_shot_iterator()",
@@ -1110,6 +1144,12 @@
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
+ def test_is_tensor_upgrade(self):
+ text = "tf.contrib.framework.is_tensor(x)"
+ expected = "tf.is_tensor(x)"
+ _, _, _, new_text = self._upgrade(text)
+ self.assertEqual(expected, new_text)
+
def test_sample_distorted_bounding_box(self):
# pylint: disable=line-too-long
text = "tf.image.sample_distorted_bounding_box(a, b, c, d, e, f, g, h, i, j)"
@@ -1125,6 +1165,20 @@
_, _, _, new_text = self._upgrade(text)
self.assertEqual(expected, new_text)
+ def test_flags_bare(self):
+ _, _, errors, _ = self._upgrade("tf.flags")
+ self.assertIn("tf.flags has been removed", errors[0])
+
+ def test_flags_flags(self):
+ _, _, errors, _ = self._upgrade("tf.flags.FLAGS")
+ self.assertIn("tf.flags has been removed", errors[0])
+
+ def test_max_pool_2d(self):
+ text = "tf.nn.max_pool(value=4)"
+ expected_text = "tf.nn.max_pool2d(input=4)"
+ _, _, _, new_text = self._upgrade(text)
+ self.assertEqual(expected_text, new_text)
+
class TestUpgradeFiles(test_util.TensorFlowTestCase):
diff --git a/tensorflow/tools/dist_test/server/BUILD b/tensorflow/tools/dist_test/server/BUILD
index 3aa53a5..56810ae 100644
--- a/tensorflow/tools/dist_test/server/BUILD
+++ b/tensorflow/tools/dist_test/server/BUILD
@@ -12,6 +12,14 @@
py_binary(
name = "grpc_tensorflow_server",
+ srcs = ["grpc_tensorflow_server.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [":grpc_tensorflow_server_lib"],
+)
+
+py_library(
+ name = "grpc_tensorflow_server_lib",
srcs = [
"grpc_tensorflow_server.py",
],
@@ -33,7 +41,7 @@
main = "parse_cluster_spec_test.py",
srcs_version = "PY2AND3",
deps = [
- ":grpc_tensorflow_server",
+ ":grpc_tensorflow_server_lib",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:client_testlib",
],
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index 4eefd31..32aa00b 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -3,13 +3,18 @@
LABEL maintainer="Clayne Robison <clayne.b.robison@intel.com>"
# These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.12
+ARG TF_BUILD_VERSION=r1.13
ARG PYTHON="python"
ARG PYTHON3_DEV=""
ARG WHL_DIR="/tmp/pip"
ARG PIP="pip"
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
+ ${PYTHON} \
+ ${PYTHON}-dev \
+ ${PYTHON}-pip \
+ ${PYTHON}-setuptools \
+ ${PYTHON}-wheel \
build-essential \
curl \
git \
@@ -17,35 +22,20 @@
libfreetype6-dev \
libhdf5-serial-dev \
libpng-dev \
- libzmq3-dev \
libssl-dev \
+ libzmq3-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
pkg-config \
rsync \
software-properties-common \
unzip \
zip \
zlib1g-dev \
- openjdk-8-jdk \
- openjdk-8-jre-headless
-
-#install Python 3
-RUN if [ ${PYTHON} = "python3.6" ]; then \
- curl https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tar.xz -o /opt/python.tar.xz && \
- cd /opt && tar xvf python.tar.xz && \
- cd /opt/*/ && ./configure && \
- make && make install; \
- else \
- apt-get install -y --no-install-recommends \
- python-dev \
- ${PYTHON3_DEV}; \
- fi
-
-RUN apt-get clean && \
+ && \
+ apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
- ${PYTHON} get-pip.py && \
- rm get-pip.py
RUN ${PIP} --no-cache-dir install \
Pillow \
@@ -57,17 +47,12 @@
matplotlib \
mock \
numpy \
+ pandas \
scipy \
sklearn \
- pandas \
&& \
${PYTHON} -m ipykernel.kernelspec
-RUN if [ "${PYTHON}" = "python3" ]; then \
- ln -s -f /usr/bin/python3 /usr/bin/python; \
- elif [ "${PYTHON}" = "python3.6" ]; then \
- ln -s -f /usr/local/bin/python3.6 /usr/bin/python; \
- fi
# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
index 3810dae..2114091 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -3,42 +3,43 @@
LABEL maintainer="Cong Xu <cong.xu@intel.com>"
# These parameters can be overridden by parameterized_docker_build.sh
-ARG TF_BUILD_VERSION=r1.11
+ARG TF_BUILD_VERSION=r1.13
ARG PYTHON="python"
ARG PYTHON3_DEV=""
ARG WHL_DIR="/tmp/pip"
ARG PIP="pip"
-RUN apt-get update && apt-get install -y --no-install-recommends \
+
+RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
+ ${PYTHON} \
+ ${PYTHON}-dev \
+ ${PYTHON}-pip \
+ ${PYTHON}-setuptools \
+ ${PYTHON}-wheel \
build-essential \
curl \
git \
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
+ libnuma-dev \
libpng-dev \
libzmq3-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ openssh-client \
+ openssh-server \
pkg-config \
- python-dev \
- ${PYTHON3_DEV} \
rsync \
software-properties-common \
unzip \
+ wget \
zip \
zlib1g-dev \
- openjdk-8-jdk \
- openjdk-8-jre-headless \
- wget \
- libnuma-dev \
- openssh-client \
- openssh-server \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN curl -fSsL -O https://bootstrap.pypa.io/get-pip.py && \
- ${PYTHON} get-pip.py && \
- rm get-pip.py
RUN ${PIP} --no-cache-dir install \
Pillow \
@@ -56,9 +57,6 @@
&& \
${PYTHON} -m ipykernel.kernelspec
-RUN if [ "${PYTHON}" = "python3" ]; then \
- ln -s -f /usr/bin/python3 /usr/bin/python; \
- fi
# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index dad2769..3f7729b 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -6,13 +6,18 @@
ARG TF_WHL_URL
# Optional parameters
-ARG TF_BUILD_VERSION=r1.9
+ARG TF_BUILD_VERSION=r1.13
ARG PYTHON="python"
ARG PYTHON_DEV="python-dev"
ARG PIP="pip"
# Pick up some TF dependencies
-RUN apt-get update && apt-get install -y --no-install-recommends \
+RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
+ ${PYTHON} \
+ ${PYTHON}-dev \
+ ${PYTHON}-pip \
+ ${PYTHON}-setuptools \
+ ${PYTHON}-wheel \
build-essential \
curl \
libfreetype6-dev \
@@ -20,8 +25,6 @@
libpng-dev \
libzmq3-dev \
pkg-config \
- ${PYTHON} \
- ${PYTHON_DEV} \
rsync \
software-properties-common \
unzip \
@@ -29,9 +32,6 @@
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
- ${PYTHON} get-pip.py && \
- rm get-pip.py
RUN ${PIP} --no-cache-dir install \
Pillow \
@@ -48,13 +48,11 @@
&& \
${PYTHON} -m ipykernel.kernelspec
+
COPY ${TF_WHL_URL} /
RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \
rm -rf /${TF_WHL_URL}
-RUN if [ "${PYTHON}" = "python3" ]; then \
- ln -s -f /usr/bin/python3 /usr/bin/python; \
- fi
# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
index 19dc45c..b0afd63 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -6,36 +6,36 @@
ARG TF_WHL_URL
# Optional parameters
-ARG TF_BUILD_VERSION=r1.11
+ARG TF_BUILD_VERSION=r1.13
ARG PYTHON="python"
ARG PYTHON_DEV="python-dev"
ARG PIP="pip"
# Pick up some TF dependencies
-RUN apt-get update && apt-get install -y --no-install-recommends \
+# RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
+ ${PYTHON} \
+ ${PYTHON}-dev \
+ ${PYTHON}-pip \
+ ${PYTHON}-setuptools \
+ ${PYTHON}-wheel \
build-essential \
curl \
libfreetype6-dev \
libhdf5-serial-dev \
+ libnuma-dev \
libpng-dev \
libzmq3-dev \
+ openssh-client \
+ openssh-server \
pkg-config \
- python \
- ${PYTHON_DEV} \
rsync \
software-properties-common \
unzip \
wget \
- libnuma-dev \
- openssh-client \
- openssh-server \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
- python get-pip.py && \
- rm get-pip.py
RUN ${PIP} --no-cache-dir install \
Pillow \
@@ -50,15 +50,13 @@
scipy \
sklearn \
&& \
- python -m ipykernel.kernelspec
+ ${PYTHON} -m ipykernel.kernelspec
+
COPY ${TF_WHL_URL} /
RUN ${PIP} install --no-cache-dir --force-reinstall /${TF_WHL_URL} && \
rm -rf /${TF_WHL_URL}
-RUN if [ "${PYTHON}" = "python3" ]; then \
- ln -s -f /usr/bin/python3 /usr/bin/python; \
- fi
# Set up our notebook config.
COPY jupyter_notebook_config.py /root/.jupyter/
diff --git a/tensorflow/tools/dockerfiles/assembler.py b/tensorflow/tools/dockerfiles/assembler.py
index 09537b7..83b72cb 100644
--- a/tensorflow/tools/dockerfiles/assembler.py
+++ b/tensorflow/tools/dockerfiles/assembler.py
@@ -34,6 +34,7 @@
import itertools
import multiprocessing
import os
+import platform
import re
import shutil
import sys
@@ -552,6 +553,13 @@
if not FLAGS.build_images:
continue
+ # Only build images for host architecture
+ proc_arch = platform.processor()
+ is_x86 = proc_arch.startswith('x86')
+ if (is_x86 and any([arch in tag for arch in ['ppc64le']]) or
+ not is_x86 and proc_arch not in tag):
+ continue
+
# Generate a temporary Dockerfile to use to build, since docker-py
# needs a filepath relative to the build context (i.e. the current
# directory)
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
index d8fabad..c09a958 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/cpu-jupyter.Dockerfile
@@ -54,6 +54,8 @@
RUN chmod a+rwx /etc/bash.bashrc
RUN ${PIP} install jupyter matplotlib
+RUN ${PIP} install jupyter_http_over_ws
+RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
RUN mkdir /.local && chmod a+rwx /.local
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
index c1f6daf..dc5b5d4 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu-jupyter.Dockerfile
@@ -30,7 +30,6 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -43,12 +42,14 @@
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-
+
ENV CI_BUILD_PYTHON python
-# Check out TensorFlow source code if --build_arg CHECKOUT_TENSORFLOW=1
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
@@ -105,6 +106,8 @@
RUN chmod a+rwx /etc/bash.bashrc
RUN ${PIP} install jupyter matplotlib
+RUN ${PIP} install jupyter_http_over_ws
+RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
RUN mkdir /.local && chmod a+rwx /.local
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
index b4dfc8b..da81397 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-cpu.Dockerfile
@@ -30,7 +30,6 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -43,12 +42,14 @@
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-
+
ENV CI_BUILD_PYTHON python
-# Check out TensorFlow source code if --build_arg CHECKOUT_TENSORFLOW=1
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
index 6d76c06..73f7a74 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu-jupyter.Dockerfile
@@ -21,23 +21,28 @@
ARG UBUNTU_VERSION=16.04
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+ARG LIB_DIR_PREFIX=x84_64
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-dev-10-0 \
- cuda-cudart-dev-10-0 \
- cuda-cufft-dev-10-0 \
- cuda-curand-dev-10-0 \
- cuda-cusolver-dev-10-0 \
- cuda-cusparse-dev-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
- libcudnn7-dev=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-dev-${CUDA/./-} \
+ cuda-cudart-dev-${CUDA/./-} \
+ cuda-cufft-dev-${CUDA/./-} \
+ cuda-curand-dev-${CUDA/./-} \
+ cuda-cusolver-dev-${CUDA/./-} \
+ cuda-cusparse-dev-${CUDA/./-} \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libcudnn7-dev=${CUDNN}+cuda${CUDA} \
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -48,14 +53,15 @@
wget \
git \
&& \
- find /usr/local/cuda-10.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
- rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+ find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer-dev=5.0.2-1+cuda10.0 \
- && rm -rf /var/lib/apt/lists/*
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
@@ -63,12 +69,13 @@
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
-ENV TF_CUDA_VERSION=10.0
-ENV TF_CUDNN_VERSION=7
-
-# Check out TensorFlow source code if --build_arg CHECKOUT_TENSORFLOW=1
+ENV TF_CUDA_VERSION=${CUDA}
+ENV TF_CUDNN_VERSION=${CUDNN%%.*}
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
@@ -125,6 +132,8 @@
RUN chmod a+rwx /etc/bash.bashrc
RUN ${PIP} install jupyter matplotlib
+RUN ${PIP} install jupyter_http_over_ws
+RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
RUN mkdir /.local && chmod a+rwx /.local
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
index 160abc8..7ae5010 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/devel-gpu.Dockerfile
@@ -21,23 +21,28 @@
ARG UBUNTU_VERSION=16.04
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+ARG LIB_DIR_PREFIX=x84_64
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-dev-10-0 \
- cuda-cudart-dev-10-0 \
- cuda-cufft-dev-10-0 \
- cuda-curand-dev-10-0 \
- cuda-cusolver-dev-10-0 \
- cuda-cusparse-dev-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
- libcudnn7-dev=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-dev-${CUDA/./-} \
+ cuda-cudart-dev-${CUDA/./-} \
+ cuda-cufft-dev-${CUDA/./-} \
+ cuda-curand-dev-${CUDA/./-} \
+ cuda-cusolver-dev-${CUDA/./-} \
+ cuda-cusparse-dev-${CUDA/./-} \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libcudnn7-dev=${CUDNN}+cuda${CUDA} \
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -48,14 +53,15 @@
wget \
git \
&& \
- find /usr/local/cuda-10.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
- rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+ find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer-dev=5.0.2-1+cuda10.0 \
- && rm -rf /var/lib/apt/lists/*
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
@@ -63,12 +69,13 @@
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
-ENV TF_CUDA_VERSION=10.0
-ENV TF_CUDNN_VERSION=7
-
-# Check out TensorFlow source code if --build_arg CHECKOUT_TENSORFLOW=1
+ENV TF_CUDA_VERSION=${CUDA}
+ENV TF_CUDNN_VERSION=${CUDNN%%.*}
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
ARG USE_PYTHON_3_NOT_2
ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
index 46252c5..f686a71 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu-jupyter.Dockerfile
@@ -21,32 +21,37 @@
ARG UBUNTU_VERSION=16.04
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
# Pick up some TF dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-10-0 \
- cuda-cufft-10-0 \
- cuda-curand-10-0 \
- cuda-cusolver-10-0 \
- cuda-cusparse-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-${CUDA/./-} \
+ cuda-cufft-${CUDA/./-} \
+ cuda-curand-${CUDA/./-} \
+ cuda-cusolver-${CUDA/./-} \
+ cuda-cusparse-${CUDA/./-} \
+ curl \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
software-properties-common \
unzip
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda10.0 \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+ && rm -rf /var/lib/apt/lists/*)
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
@@ -82,6 +87,8 @@
RUN chmod a+rwx /etc/bash.bashrc
RUN ${PIP} install jupyter matplotlib
+RUN ${PIP} install jupyter_http_over_ws
+RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
RUN mkdir /.local && chmod a+rwx /.local
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
index 80e427f..00664b6 100644
--- a/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
+++ b/tensorflow/tools/dockerfiles/dockerfiles/gpu.Dockerfile
@@ -21,32 +21,37 @@
ARG UBUNTU_VERSION=16.04
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
# Pick up some TF dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-10-0 \
- cuda-cufft-10-0 \
- cuda-curand-10-0 \
- cuda-cusolver-10-0 \
- cuda-cusparse-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-${CUDA/./-} \
+ cuda-cufft-${CUDA/./-} \
+ cuda-curand-${CUDA/./-} \
+ cuda-cusolver-${CUDA/./-} \
+ cuda-cusparse-${CUDA/./-} \
+ curl \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
software-properties-common \
unzip
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda10.0 \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+ && rm -rf /var/lib/apt/lists/*)
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
new file mode 100644
index 0000000..beb3292
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le-jupyter.Dockerfile
@@ -0,0 +1,92 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+FROM ubuntu:${UBUNTU_VERSION} as base
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+# Options:
+# tensorflow
+# tensorflow-gpu
+# tf-nightly
+# tf-nightly-gpu
+ARG TF_PACKAGE=tensorflow
+RUN apt-get update && apt-get install -y wget libhdf5-dev
+RUN ${PIP} install --global-option=build_ext \
+ --global-option=-I/usr/include/hdf5/serial/ \
+ --global-option=-L/usr/lib/powerpc64le-linux-gnu/hdf5/serial \
+ h5py
+
+# CACHE_STOP is used to rerun future commands, otherwise downloading the .whl will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tensorflow ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ fi; \
+ MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \
+ MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \
+ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \
+ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \
+ ${PIP} install ${PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter matplotlib
+
+RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
+RUN mkdir /.local && chmod a+rwx /.local
+RUN apt-get install -y --no-install-recommends wget
+WORKDIR /tf/tensorflow-tutorials
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+COPY readme-for-jupyter.md README.md
+RUN apt-get autoremove -y && apt-get remove -y wget
+WORKDIR /tf
+EXPOSE 8888
+
+RUN ${PYTHON} -m ipykernel.kernelspec
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
new file mode 100644
index 0000000..083d61b
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/cpu-ppc64le.Dockerfile
@@ -0,0 +1,75 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+FROM ubuntu:${UBUNTU_VERSION} as base
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+# Options:
+# tensorflow
+# tensorflow-gpu
+# tf-nightly
+# tf-nightly-gpu
+ARG TF_PACKAGE=tensorflow
+RUN apt-get update && apt-get install -y wget libhdf5-dev
+RUN ${PIP} install --global-option=build_ext \
+ --global-option=-I/usr/include/hdf5/serial/ \
+ --global-option=-L/usr/lib/powerpc64le-linux-gnu/hdf5/serial \
+ h5py
+
+# CACHE_STOP is used to rerun future commands, otherwise downloading the .whl will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tensorflow ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ fi; \
+ MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \
+ MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \
+ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \
+ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \
+ ${PIP} install ${PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
new file mode 100644
index 0000000..1f32849
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le-jupyter.Dockerfile
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+FROM ubuntu:${UBUNTU_VERSION} AS base
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ENV CI_BUILD_PYTHON python
+
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
+ARG CHECKOUT_TF_SRC=0
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ keras_applications \
+ keras_preprocessing \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \
+ enum34
+
+ # Build and install bazel
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip && \
+ unzip bazel-$BAZEL_VERSION-dist.zip && \
+ bash ./compile.sh && \
+ cp output/bazel /usr/local/bin/ && \
+ rm -rf /bazel && \
+ cd -
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter matplotlib
+
+RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
+RUN mkdir /.local && chmod a+rwx /.local
+RUN apt-get install -y --no-install-recommends wget
+WORKDIR /tf/tensorflow-tutorials
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+COPY readme-for-jupyter.md README.md
+RUN apt-get autoremove -y && apt-get remove -y wget
+WORKDIR /tf
+EXPOSE 8888
+
+RUN ${PYTHON} -m ipykernel.kernelspec
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
new file mode 100644
index 0000000..cda51c3
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-cpu-ppc64le.Dockerfile
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+FROM ubuntu:${UBUNTU_VERSION} AS base
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ curl \
+ git \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ openjdk-8-jdk \
+ openjdk-8-jre-headless \
+ && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+ENV CI_BUILD_PYTHON python
+
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
+ARG CHECKOUT_TF_SRC=0
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ keras_applications \
+ keras_preprocessing \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \
+ enum34
+
+ # Build and install bazel
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip && \
+ unzip bazel-$BAZEL_VERSION-dist.zip && \
+ bash ./compile.sh && \
+ cp output/bazel /usr/local/bin/ && \
+ rm -rf /bazel && \
+ cd -
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
new file mode 100644
index 0000000..d8ee19f
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le-jupyter.Dockerfile
@@ -0,0 +1,151 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+ARG LIB_DIR_PREFIX=x84_64
+
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-dev-${CUDA/./-} \
+ cuda-cudart-dev-${CUDA/./-} \
+ cuda-cufft-dev-${CUDA/./-} \
+ cuda-curand-dev-${CUDA/./-} \
+ cuda-cusolver-dev-${CUDA/./-} \
+ cuda-cusparse-dev-${CUDA/./-} \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ git \
+ && \
+ find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
+ && apt-get update \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
+
+# Configure the build for our CUDA configuration.
+ENV CI_BUILD_PYTHON python
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
+ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
+ENV TF_CUDA_VERSION=${CUDA}
+ENV TF_CUDNN_VERSION=${CUDNN%%.*}
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
+ARG CHECKOUT_TF_SRC=0
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ keras_applications \
+ keras_preprocessing \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \
+ enum34
+
+ # Build and install bazel
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip && \
+ unzip bazel-$BAZEL_VERSION-dist.zip && \
+ bash ./compile.sh && \
+ cp output/bazel /usr/local/bin/ && \
+ rm -rf /bazel && \
+ cd -
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter matplotlib
+
+RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
+RUN mkdir /.local && chmod a+rwx /.local
+RUN apt-get install -y --no-install-recommends wget
+WORKDIR /tf/tensorflow-tutorials
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+COPY readme-for-jupyter.md README.md
+RUN apt-get autoremove -y && apt-get remove -y wget
+WORKDIR /tf
+EXPOSE 8888
+
+RUN ${PYTHON} -m ipykernel.kernelspec
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
new file mode 100644
index 0000000..9660706
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/devel-gpu-ppc64le.Dockerfile
@@ -0,0 +1,134 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+ARG LIB_DIR_PREFIX=x84_64
+
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-dev-${CUDA/./-} \
+ cuda-cudart-dev-${CUDA/./-} \
+ cuda-cufft-dev-${CUDA/./-} \
+ cuda-curand-dev-${CUDA/./-} \
+ cuda-cusolver-dev-${CUDA/./-} \
+ cuda-cusparse-dev-${CUDA/./-} \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libcudnn7-dev=${CUDNN}+cuda${CUDA} \
+ libcurl3-dev \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ rsync \
+ software-properties-common \
+ unzip \
+ zip \
+ zlib1g-dev \
+ wget \
+ git \
+ && \
+ find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
+
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
+ && apt-get update \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
+
+# Configure the build for our CUDA configuration.
+ENV CI_BUILD_PYTHON python
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+ENV TF_NEED_CUDA 1
+ENV TF_NEED_TENSORRT 1
+ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
+ENV TF_CUDA_VERSION=${CUDA}
+ENV TF_CUDNN_VERSION=${CUDNN%%.*}
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build_arg CHECKOUT_TF_SRC=1
+ARG CHECKOUT_TF_SRC=0
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ keras_applications \
+ keras_preprocessing \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \
+ enum34
+
+ # Build and install bazel
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip && \
+ unzip bazel-$BAZEL_VERSION-dist.zip && \
+ bash ./compile.sh && \
+ cp output/bazel /usr/local/bin/ && \
+ rm -rf /bazel && \
+ cd -
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
new file mode 100644
index 0000000..449a8d8
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le-jupyter.Dockerfile
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-${CUDA/./-} \
+ cuda-cufft-${CUDA/./-} \
+ cuda-curand-${CUDA/./-} \
+ cuda-cusolver-${CUDA/./-} \
+ cuda-cusparse-${CUDA/./-} \
+ curl \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip
+
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
+ && apt-get update \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
+
+# For CUDA profiling, TensorFlow requires CUPTI.
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+# Options:
+# tensorflow
+# tensorflow-gpu
+# tf-nightly
+# tf-nightly-gpu
+ARG TF_PACKAGE=tensorflow
+RUN apt-get update && apt-get install -y wget libhdf5-dev
+RUN ${PIP} install --global-option=build_ext \
+ --global-option=-I/usr/include/hdf5/serial/ \
+ --global-option=-L/usr/lib/powerpc64le-linux-gnu/hdf5/serial \
+ h5py
+
+# CACHE_STOP is used to rerun future commands, otherwise downloading the .whl will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tensorflow ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ fi; \
+ MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \
+ MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \
+ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \
+ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \
+ ${PIP} install ${PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
+
+RUN ${PIP} install jupyter matplotlib
+
+RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
+RUN mkdir /.local && chmod a+rwx /.local
+RUN apt-get install -y --no-install-recommends wget
+WORKDIR /tf/tensorflow-tutorials
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_classification.ipynb
+RUN wget https://raw.githubusercontent.com/tensorflow/docs/master/site/en/tutorials/keras/basic_text_classification.ipynb
+COPY readme-for-jupyter.md README.md
+RUN apt-get autoremove -y && apt-get remove -y wget
+WORKDIR /tf
+EXPOSE 8888
+
+RUN ${PYTHON} -m ipykernel.kernelspec
+
+CMD ["bash", "-c", "source /etc/bash.bashrc && jupyter notebook --notebook-dir=/tf --ip 0.0.0.0 --no-browser --allow-root"]
diff --git a/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
new file mode 100644
index 0000000..f01a47f
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/dockerfiles/ppc64le/gpu-ppc64le.Dockerfile
@@ -0,0 +1,108 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+#
+# THIS IS A GENERATED DOCKERFILE.
+#
+# This file was assembled from multiple pieces, whose use is documented
+# throughout. Please refer to the TensorFlow dockerfiles documentation
+# for more information.
+
+ARG UBUNTU_VERSION=16.04
+
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
+# Pick up some TF dependencies
+RUN apt-get update && apt-get install -y --no-install-recommends \
+ build-essential \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-${CUDA/./-} \
+ cuda-cufft-${CUDA/./-} \
+ cuda-curand-${CUDA/./-} \
+ cuda-cusolver-${CUDA/./-} \
+ cuda-cusparse-${CUDA/./-} \
+ curl \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libfreetype6-dev \
+ libhdf5-serial-dev \
+ libzmq3-dev \
+ pkg-config \
+ software-properties-common \
+ unzip
+
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
+ && apt-get update \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
+
+# For CUDA profiling, TensorFlow requires CUPTI.
+ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
+
+ARG USE_PYTHON_3_NOT_2
+ARG _PY_SUFFIX=${USE_PYTHON_3_NOT_2:+3}
+ARG PYTHON=python${_PY_SUFFIX}
+ARG PIP=pip${_PY_SUFFIX}
+
+# See http://bugs.python.org/issue19846
+ENV LANG C.UTF-8
+
+RUN apt-get update && apt-get install -y \
+ ${PYTHON} \
+ ${PYTHON}-pip
+
+RUN ${PIP} --no-cache-dir install --upgrade \
+ pip \
+ setuptools
+
+# Some TF tools expect a "python" binary
+RUN ln -s $(which ${PYTHON}) /usr/local/bin/python
+
+# Options:
+# tensorflow
+# tensorflow-gpu
+# tf-nightly
+# tf-nightly-gpu
+ARG TF_PACKAGE=tensorflow
+RUN apt-get update && apt-get install -y wget libhdf5-dev
+RUN ${PIP} install --global-option=build_ext \
+ --global-option=-I/usr/include/hdf5/serial/ \
+ --global-option=-L/usr/lib/powerpc64le-linux-gnu/hdf5/serial \
+ h5py
+
+# CACHE_STOP is used to rerun future commands, otherwise downloading the .whl will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tensorflow ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ fi; \
+ MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \
+ MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \
+ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \
+ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \
+ ${PIP} install ${PACKAGE}
+
+COPY bashrc /etc/bash.bashrc
+RUN chmod a+rwx /etc/bash.bashrc
diff --git a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
index c4ec609..c056d91 100644
--- a/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/jupyter.partial.Dockerfile
@@ -1,4 +1,6 @@
RUN ${PIP} install jupyter matplotlib
+RUN ${PIP} install jupyter_http_over_ws
+RUN jupyter serverextension enable --py jupyter_http_over_ws
RUN mkdir -p /tf/tensorflow-tutorials && chmod -R a+rwx /tf/
RUN mkdir /.local && chmod a+rwx /.local
diff --git a/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile
new file mode 100644
index 0000000..1e79574
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/tensorflow-ppc64le.partial.Dockerfile
@@ -0,0 +1,28 @@
+# Options:
+# tensorflow
+# tensorflow-gpu
+# tf-nightly
+# tf-nightly-gpu
+ARG TF_PACKAGE=tensorflow
+RUN apt-get update && apt-get install -y wget libhdf5-dev
+RUN ${PIP} install --global-option=build_ext \
+ --global-option=-I/usr/include/hdf5/serial/ \
+ --global-option=-L/usr/lib/powerpc64le-linux-gnu/hdf5/serial \
+ h5py
+
+# CACHE_STOP is used to rerun future commands, otherwise downloading the .whl will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+RUN if [ ${TF_PACKAGE} = tensorflow-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly-gpu ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_GPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tensorflow ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Release_Build/lastSuccessfulBuild/; \
+ elif [ ${TF_PACKAGE} = tf-nightly ]; then \
+ BASE=https://powerci.osuosl.org/job/TensorFlow_PPC64LE_CPU_Nightly_Artifact/lastSuccessfulBuild/; \
+ fi; \
+ MAJOR=`${PYTHON} -c 'import sys; print(sys.version_info[0])'`; \
+ MINOR=`${PYTHON} -c 'import sys; print(sys.version_info[1])'`; \
+ PACKAGE=$(wget -qO- ${BASE}"api/xml?xpath=//fileName&wrapper=artifacts" | grep -o "[^<>]*cp${MAJOR}${MINOR}[^<>]*.whl"); \
+ wget ${BASE}"artifact/tensorflow_pkg/"${PACKAGE}; \
+ ${PIP} install ${PACKAGE}
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile
new file mode 100644
index 0000000..0397ab5
--- /dev/null
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/bazelbuild.partial.Dockerfile
@@ -0,0 +1,33 @@
+RUN apt-get update && apt-get install -y \
+ build-essential \
+ curl \
+ git \
+ openjdk-8-jdk \
+ ${PYTHON}-dev \
+ swig
+
+RUN ${PIP} --no-cache-dir install \
+ Pillow \
+ h5py \
+ keras_applications \
+ keras_preprocessing \
+ matplotlib \
+ mock \
+ numpy \
+ scipy \
+ sklearn \
+ pandas \
+ && test "${USE_PYTHON_3_NOT_2}" -eq 1 && true || ${PIP} --no-cache-dir install \
+ enum34
+
+ # Build and install bazel
+ENV BAZEL_VERSION 0.15.0
+WORKDIR /
+RUN mkdir /bazel && \
+ cd /bazel && \
+ curl -fSsL -O https://github.com/bazelbuild/bazel/releases/download/$BAZEL_VERSION/bazel-$BAZEL_VERSION-dist.zip && \
+ unzip bazel-$BAZEL_VERSION-dist.zip && \
+ bash ./compile.sh && \
+ cp output/bazel /usr/local/bin/ && \
+ rm -rf /bazel && \
+ cd -
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
index 0652ac4..a1fd901 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-cpu.partial.Dockerfile
@@ -7,7 +7,6 @@
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -20,9 +19,11 @@
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
-
+
ENV CI_BUILD_PYTHON python
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
index 2b4494a..cb153fa 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/devel-nvidia.partial.Dockerfile
@@ -1,20 +1,25 @@
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+ARG LIB_DIR_PREFIX=x84_64
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-dev-10-0 \
- cuda-cudart-dev-10-0 \
- cuda-cufft-dev-10-0 \
- cuda-curand-dev-10-0 \
- cuda-cusolver-dev-10-0 \
- cuda-cusparse-dev-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
- libcudnn7-dev=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-dev-${CUDA/./-} \
+ cuda-cudart-dev-${CUDA/./-} \
+ cuda-cufft-dev-${CUDA/./-} \
+ cuda-curand-dev-${CUDA/./-} \
+ cuda-cusolver-dev-${CUDA/./-} \
+ cuda-cusparse-dev-${CUDA/./-} \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
+ libcudnn7-dev=${CUDNN}+cuda${CUDA} \
libcurl3-dev \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
rsync \
@@ -25,14 +30,15 @@
wget \
git \
&& \
- find /usr/local/cuda-10.0/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
- rm /usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a
+ find /usr/local/cuda-${CUDA}/lib64/ -type f -name 'lib*_static.a' -not -name 'libcudart_static.a' -delete && \
+ rm /usr/lib/${LIB_DIR_PREFIX}-linux-gnu/libcudnn_static_v7.a
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer-dev=5.0.2-1+cuda10.0 \
- && rm -rf /var/lib/apt/lists/*
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*)
# Configure the build for our CUDA configuration.
ENV CI_BUILD_PYTHON python
@@ -40,9 +46,10 @@
ENV TF_NEED_CUDA 1
ENV TF_NEED_TENSORRT 1
ENV TF_CUDA_COMPUTE_CAPABILITIES=3.5,5.2,6.0,6.1,7.0
-ENV TF_CUDA_VERSION=10.0
-ENV TF_CUDNN_VERSION=7
-
-# Check out TensorFlow source code if --build_arg CHECKOUT_TENSORFLOW=1
+ENV TF_CUDA_VERSION=${CUDA}
+ENV TF_CUDNN_VERSION=${CUDNN%%.*}
+# CACHE_STOP is used to rerun future commands, otherwise cloning tensorflow will be cached and will not pull the most recent version
+ARG CACHE_STOP=1
+# Check out TensorFlow source code if --build-arg CHECKOUT_TF_SRC=1
ARG CHECKOUT_TF_SRC=0
-RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src
+RUN test "${CHECKOUT_TF_SRC}" -eq 1 && git clone https://github.com/tensorflow/tensorflow.git /tensorflow_src || true
diff --git a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
index a6393a3..1d40ed5 100644
--- a/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
+++ b/tensorflow/tools/dockerfiles/partials/ubuntu/nvidia.partial.Dockerfile
@@ -1,29 +1,34 @@
-FROM nvidia/cuda:10.0-base-ubuntu${UBUNTU_VERSION} as base
+ARG ARCH=
+ARG CUDA=10.0
+FROM nvidia/cuda${ARCH:+-$ARCH}:${CUDA}-base-ubuntu${UBUNTU_VERSION} as base
+ARG CUDNN=7.4.1.5-1
+# Needed for string substitution
+SHELL ["/bin/bash", "-c"]
# Pick up some TF dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
- cuda-command-line-tools-10-0 \
- cuda-cublas-10-0 \
- cuda-cufft-10-0 \
- cuda-curand-10-0 \
- cuda-cusolver-10-0 \
- cuda-cusparse-10-0 \
- libcudnn7=7.4.1.5-1+cuda10.0 \
+ cuda-command-line-tools-${CUDA/./-} \
+ cuda-cublas-${CUDA/./-} \
+ cuda-cufft-${CUDA/./-} \
+ cuda-curand-${CUDA/./-} \
+ cuda-cusolver-${CUDA/./-} \
+ cuda-cusparse-${CUDA/./-} \
+ curl \
+ libcudnn7=${CUDNN}+cuda${CUDA} \
libfreetype6-dev \
libhdf5-serial-dev \
- libpng12-dev \
libzmq3-dev \
pkg-config \
software-properties-common \
unzip
-RUN apt-get update && \
- apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda10.0 \
+RUN [ ${ARCH} = ppc64le ] || (apt-get update && \
+ apt-get install nvinfer-runtime-trt-repo-ubuntu1604-5.0.2-ga-cuda${CUDA} \
&& apt-get update \
- && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda10.0 \
+ && apt-get install -y --no-install-recommends libnvinfer5=5.0.2-1+cuda${CUDA} \
&& apt-get clean \
- && rm -rf /var/lib/apt/lists/*
+ && rm -rf /var/lib/apt/lists/*)
# For CUDA profiling, TensorFlow requires CUPTI.
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
diff --git a/tensorflow/tools/dockerfiles/spec.yml b/tensorflow/tools/dockerfiles/spec.yml
index 19d96e7..d19b1d1 100644
--- a/tensorflow/tools/dockerfiles/spec.yml
+++ b/tensorflow/tools/dockerfiles/spec.yml
@@ -56,6 +56,13 @@
- "{ubuntu}{jupyter}"
- "{ubuntu-devel}{jupyter}"
+ ppc64le-dockerfiles:
+ is_dockerfiles: true
+ upload_images: false
+ tag_specs:
+ - "{ubuntu-ppc64le}{jupyter}"
+ - "{ubuntu-devel-ppc64le}{jupyter}"
+
slice_sets:
py:
@@ -122,6 +129,70 @@
args:
- CHECKOUT_TF_SRC=1
+ ubuntu-ppc64le:
+ - add_to_name: "-ppc64le"
+ dockerfile_exclusive_name: "cpu-ppc64le"
+ dockerfile_subdirectory: "ppc64le"
+ args:
+ - UBUNTU_VERSION=18.04
+ partials:
+ - ubuntu/version
+ - ubuntu/cpu
+ - ubuntu/python
+ - tensorflow-ppc64le
+ - shell
+ - add_to_name: "-gpu-ppc64le"
+ dockerfile_exclusive_name: "gpu-ppc64le"
+ dockerfile_subdirectory: "ppc64le"
+ args:
+ - UBUNTU_VERSION=18.04
+ - ARCH=ppc64le
+ - CUDA=10.0
+ - TF_PACKAGE=tensorflow-gpu
+ partials:
+ - ubuntu/version
+ - ubuntu/nvidia
+ - ubuntu/python
+ - tensorflow-ppc64le
+ - shell
+ tests:
+ - import-gpu.sh
+ test_runtime: nvidia
+
+ ubuntu-devel-ppc64le:
+ - add_to_name: "devel-ppc64le"
+ dockerfile_exclusive_name: "devel-cpu-ppc64le"
+ dockerfile_subdirectory: "ppc64le"
+ partials:
+ - ubuntu/version
+ - ubuntu/devel-cpu
+ - ubuntu/python
+ - ubuntu/bazelbuild
+ - shell
+ tests:
+ - build-cpu.sh
+ args:
+ - UBUNTU_VERSION=18.04
+ - CHECKOUT_TF_SRC=1
+ - add_to_name: "devel-gpu-ppc64le"
+ dockerfile_exclusive_name: "devel-gpu-ppc64le"
+ dockerfile_subdirectory: "ppc64le"
+ args:
+ - UBUNTU_VERSION=18.04
+ - ARCH=ppc64le
+ - CUDA=10.0
+ - LIB_DIR_PREFIX=powerpc64le
+ - CHECKOUT_TF_SRC=1
+ partials:
+ - ubuntu/version
+ - ubuntu/devel-nvidia
+ - ubuntu/python
+ - ubuntu/bazelbuild
+ - shell
+ tests:
+ - build-gpu.sh
+ test_runtime: nvidia
+
nightly:
- add_to_name: "nightly"
partials:
diff --git a/tensorflow/tools/dockerfiles/tools.Dockerfile b/tensorflow/tools/dockerfiles/tools.Dockerfile
index e892929..a96b257 100644
--- a/tensorflow/tools/dockerfiles/tools.Dockerfile
+++ b/tensorflow/tools/dockerfiles/tools.Dockerfile
@@ -17,7 +17,7 @@
#
# You can use this image to quickly develop changes to the Dockerfile assembler
# or set of TF Docker partials. See README.md for usage instructions.
-FROM debian:stretch
+FROM ubuntu:16.04
LABEL maintainer="Austin Anderson <angerson@google.com>"
RUN apt-get update && apt-get install -y python3 python3-pip bash curl
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index c2449da..4d52c1f 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -189,7 +189,7 @@
git_version: the result of a git describe.
"""
if b"\"" in git_version or b"\\" in git_version:
- git_version = "git_version_is_invalid" # do not cause build to fail!
+ git_version = b"git_version_is_invalid" # do not cause build to fail!
contents = """/* Generated by gen_git_source.py */
#include <string>
const char* tf_git_version() {return "%s";}
@@ -216,7 +216,7 @@
return 0;
#endif
}
-""" % git_version
+""" % git_version.decode("utf-8")
open(filename, "w").write(contents)
diff --git a/tensorflow/tools/graph_transforms/BUILD b/tensorflow/tools/graph_transforms/BUILD
index 41ed31a..2145b3b 100644
--- a/tensorflow/tools/graph_transforms/BUILD
+++ b/tensorflow/tools/graph_transforms/BUILD
@@ -136,6 +136,15 @@
] + if_not_windows([
"//tensorflow/core/kernels:remote_fused_graph_rewriter_transform",
"//tensorflow/core/kernels/hexagon:hexagon_rewriter_transform",
+ "//tensorflow/core:sparse_ops_op_lib",
+ "//tensorflow/core:parsing_ops_op_lib",
+ "//tensorflow/core:sendrecv_ops_op_lib",
+ "//tensorflow/core:io_ops_op_lib",
+ "//tensorflow/core:logging_ops_op_lib",
+ "//tensorflow/core:lookup_ops_op_lib",
+ "//tensorflow/core:data_flow_ops_op_lib",
+ "//tensorflow/core:no_op_op_lib",
+ "//tensorflow/core:state_ops_op_lib",
"//tensorflow/core:user_ops_op_lib",
"//tensorflow/core:training_ops_op_lib",
"//tensorflow/core:string_ops_op_lib",
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
index 16a0f7d..f59a7ab 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms.cc
@@ -37,7 +37,7 @@
input_graph_def, // clang-format off
{"Mul", // mul_node
{
- {"Conv2D|MatMul", // conv_node
+ {"Conv2D|MatMul|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
@@ -72,8 +72,15 @@
// Make sure all the inputs really are vectors, with as many entries as
// there are columns in the weights.
- const int weights_cols_index = conv_node.op() == "Conv2D" ? 3 : 1;
- const int64 weights_cols = weights.shape().dim_size(weights_cols_index);
+ int64 weights_cols;
+ if (conv_node.op() == "Conv2D") {
+ weights_cols = weights.shape().dim_size(3);
+ } else if (conv_node.op() == "DepthwiseConv2dNative") {
+ weights_cols =
+ weights.shape().dim_size(2) * weights.shape().dim_size(3);
+ } else {
+ weights_cols = weights.shape().dim_size(1);
+ }
if ((mul_values.shape().dims() != 1) ||
(mul_values.shape().dim_size(0) != weights_cols)) {
return errors::InvalidArgument(
@@ -82,14 +89,13 @@
}
// Multiply the original weights by the scale vector.
- auto weights_matrix = weights.flat_inner_dims<float>();
+ auto weights_vector = weights.flat<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
- auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
- for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
- for (int64 col = 0; col < weights_cols; ++col) {
- scaled_weights_matrix(row, col) =
- weights_matrix(row, col) * mul_values.flat<float>()(col);
- }
+ auto scaled_weights_vector = scaled_weights.flat<float>();
+ for (int64 row = 0; row < weights_vector.dimension(0); ++row) {
+ scaled_weights_vector(row) =
+ weights_vector(row) *
+ mul_values.flat<float>()(row % weights_cols);
}
// Construct the new nodes.
diff --git a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
index a5d541f..885fbd5 100644
--- a/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_batch_norms_test.cc
@@ -87,6 +87,57 @@
}
}
+ void TestFoldBatchNormsDepthwiseConv2dNative() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
+ weights_op, {1, 1, 1, 1}, "VALID");
+
+ Tensor mul_values_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&mul_values_data, {2.0f, 3.0f, 4.0f, 5.0f});
+ Output mul_values_op = Const(root.WithOpName("mul_values"),
+ Input::Initializer(mul_values_data));
+
+ Output mul_op = Mul(root.WithOpName("output"), conv_op, mul_values_op);
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(
+ FoldBatchNorms(original_graph_def, {{}, {"output"}}, &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("Mul", node.op());
+ }
+ }
+
void TestFoldBatchNormsConv2DShared() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -202,6 +253,9 @@
TEST_F(FoldBatchNormsTest, TestFoldBatchNormsMatMul) {
TestFoldBatchNormsMatMul();
}
+TEST_F(FoldBatchNormsTest, TestFoldBatchNormsDepthwiseConv2dNative) {
+ TestFoldBatchNormsDepthwiseConv2dNative();
+}
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
index fd546f8..532b460 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms.cc
@@ -109,24 +109,29 @@
const string& conv_output_name,
std::vector<NodeDef>* new_nodes) {
const NodeDef& conv_node = conv_node_match.node;
- CHECK_EQ("Conv2D", conv_node.op());
+ // CHECK_EQ("Conv2D", conv_node.op());
const NodeDef& input_node = conv_node_match.inputs[0].node;
const NodeDef& weights_node = conv_node_match.inputs[1].node;
CHECK_EQ("Const", weights_node.op());
Tensor weights = GetNodeTensorAttr(weights_node, "value");
- const int64 weights_cols = weights.shape().dim_size(3);
+ int64 weights_cols;
+ if (conv_node.op() == "Conv2D") {
+ weights_cols = weights.shape().dim_size(3);
+ } else if (conv_node.op() == "DepthwiseConv2dNative") {
+ weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3);
+ } else {
+ weights_cols = weights.shape().dim_size(1);
+ }
CHECK_EQ(weights_cols, scale_values.size());
// Multiply the original weights by the scale vector.
- auto weights_matrix = weights.flat_inner_dims<float>();
+ auto weights_vector = weights.flat<float>();
Tensor scaled_weights(DT_FLOAT, weights.shape());
- auto scaled_weights_matrix = scaled_weights.flat_inner_dims<float>();
- for (int64 row = 0; row < weights_matrix.dimension(0); ++row) {
- for (int64 col = 0; col < weights_cols; ++col) {
- scaled_weights_matrix(row, col) =
- weights_matrix(row, col) * scale_values[col];
- }
+ auto scaled_weights_vector = scaled_weights.flat<float>();
+ for (int64 row = 0; row < weights_vector.dimension(0); ++row) {
+ scaled_weights_vector(row) =
+ weights_vector(row) * scale_values[row % weights_cols];
}
// Figure out the remaining bias to add on.
Tensor bias_offset(DT_FLOAT, {weights_cols});
@@ -158,7 +163,7 @@
NodeDef bias_add_node;
bias_add_node.set_op("BiasAdd");
bias_add_node.set_name(conv_output_name);
- if (conv_node.attr().count("data_format") > 0) {
+ if (!conv_node.attr().count("data_format")) {
CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node);
}
CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
@@ -185,7 +190,7 @@
}
Status FuseBatchNormWithBatchToSpace(const NodeMatch& match,
- std::vector<NodeDef>* new_nodes) {
+ std::vector<NodeDef>* new_nodes) {
// Calculate the scale and offset values to apply.
std::vector<float> scale_values;
std::vector<float> offset_values;
@@ -200,9 +205,8 @@
const NodeDef& conv_node = conv_node_match.node;
string biasadd_name = conv_node.name() + "/biasadd";
- TF_RETURN_IF_ERROR(
- FuseScaleOffsetToConvWeights(scale_values, offset_values, conv_node_match,
- biasadd_name , new_nodes));
+ TF_RETURN_IF_ERROR(FuseScaleOffsetToConvWeights(
+ scale_values, offset_values, conv_node_match, biasadd_name, new_nodes));
NodeDef new_batch_to_space_node = batch_to_space_node;
// reuse batch_norm node name
@@ -292,7 +296,7 @@
current_graph_def, // clang-format off
{"BatchNormWithGlobalNormalization|FusedBatchNorm", // batch_norm_node
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
@@ -325,7 +329,7 @@
{
{"BatchToSpaceND", // batch_to_space_node
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
@@ -363,13 +367,13 @@
{
{"ConcatV2|Concat", // concat two conv2d.
{
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
}
},
- {"Conv2D", // conv_node
+ {"Conv2D|DepthwiseConv2dNative", // conv_node
{
{"*"}, // input_node
{"Const"}, // weights_node
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index 6c71749..c5fa9b1 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -121,6 +121,84 @@
}
}
+ void TestFoldOldBatchNormsAfterDepthwiseConv2dNative() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
+ weights_op, {1, 1, 1, 1}, "VALID");
+
+ Tensor mean_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
+ Output mean_op =
+ Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
+ Output variance_op = Const(root.WithOpName("variance_op"),
+ Input::Initializer(variance_data));
+
+ Tensor beta_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
+ Output beta_op =
+ Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
+
+ Tensor gamma_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
+ Output gamma_op =
+ Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ NodeDef batch_norm_node;
+ batch_norm_node.set_op("BatchNormWithGlobalNormalization");
+ batch_norm_node.set_name("output");
+ AddNodeInput("conv_op", &batch_norm_node);
+ AddNodeInput("mean_op", &batch_norm_node);
+ AddNodeInput("variance_op", &batch_norm_node);
+ AddNodeInput("beta_op", &batch_norm_node);
+ AddNodeInput("gamma_op", &batch_norm_node);
+ SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
+ SetNodeAttr("variance_epsilon", 0.00001f, &batch_norm_node);
+ SetNodeAttr("scale_after_normalization", false, &batch_norm_node);
+ *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
+ original_graph_def.mutable_versions()->set_producer(8);
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
+ &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 1e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("BatchNormWithGlobalNormalization", node.op());
+ }
+ }
+
void TestFoldFusedBatchNorms() {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -198,6 +276,83 @@
}
}
+ void TestFoldFusedBatchNormsAfterDepthwiseConv2dNative() {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT, TensorShape({1, 1, 6, 2}));
+ test::FillValues<float>(
+ &input_data, {1.0f, 4.0f, 2.0f, 5.0f, 3.0f, 6.0f, -1.0f, -4.0f, -2.0f,
+ -5.0f, -3.0f, -6.0f});
+ Output input_op =
+ Const(root.WithOpName("input_op"), Input::Initializer(input_data));
+
+ Tensor weights_data(DT_FLOAT, TensorShape({1, 2, 2, 2}));
+ test::FillValues<float>(&weights_data,
+ {1.0f, 2.0f, 3.0f, 4.0f, 0.1f, 0.2f, 0.3f, 0.4f});
+ Output weights_op =
+ Const(root.WithOpName("weights_op"), Input::Initializer(weights_data));
+
+ Output conv_op = DepthwiseConv2dNative(root.WithOpName("conv_op"), input_op,
+ weights_op, {1, 1, 1, 1}, "VALID");
+
+ Tensor mean_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&mean_data, {10.0f, 20.0f, 30.0f, 40.0f});
+ Output mean_op =
+ Const(root.WithOpName("mean_op"), Input::Initializer(mean_data));
+
+ Tensor variance_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&variance_data, {0.25f, 0.5f, 0.75f, 1.0f});
+ Output variance_op = Const(root.WithOpName("variance_op"),
+ Input::Initializer(variance_data));
+
+ Tensor beta_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&beta_data, {0.1f, 0.6f, 1.1f, 1.6f});
+ Output beta_op =
+ Const(root.WithOpName("beta_op"), Input::Initializer(beta_data));
+
+ Tensor gamma_data(DT_FLOAT, TensorShape({4}));
+ test::FillValues<float>(&gamma_data, {1.0f, 2.0f, 3.0f, 4.0f});
+ Output gamma_op =
+ Const(root.WithOpName("gamma_op"), Input::Initializer(gamma_data));
+
+ GraphDef original_graph_def;
+ TF_ASSERT_OK(root.ToGraphDef(&original_graph_def));
+
+ NodeDef batch_norm_node;
+ batch_norm_node.set_op("FusedBatchNorm");
+ batch_norm_node.set_name("output");
+ AddNodeInput("conv_op", &batch_norm_node);
+ AddNodeInput("gamma_op", &batch_norm_node);
+ AddNodeInput("beta_op", &batch_norm_node);
+ AddNodeInput("mean_op", &batch_norm_node);
+ AddNodeInput("variance_op", &batch_norm_node);
+ SetNodeAttr("T", DT_FLOAT, &batch_norm_node);
+ SetNodeAttr("epsilon", 0.00001f, &batch_norm_node);
+ SetNodeAttr("is_training", false, &batch_norm_node);
+ *(original_graph_def.mutable_node()->Add()) = batch_norm_node;
+
+ std::unique_ptr<Session> original_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(original_session->Create(original_graph_def));
+ std::vector<Tensor> original_outputs;
+ TF_ASSERT_OK(original_session->Run({}, {"output"}, {}, &original_outputs));
+
+ GraphDef fused_graph_def;
+ TF_ASSERT_OK(FoldOldBatchNorms(original_graph_def, {{}, {"output"}},
+ &fused_graph_def));
+
+ std::unique_ptr<Session> fused_session(NewSession(SessionOptions()));
+ TF_ASSERT_OK(fused_session->Create(fused_graph_def));
+ std::vector<Tensor> fused_outputs;
+ TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
+
+ test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
+
+ for (const NodeDef& node : fused_graph_def.node()) {
+ EXPECT_NE("FusedBatchNorm", node.op());
+ }
+ }
+
void TestFoldFusedBatchNormsWithConcat(const bool split) {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
@@ -321,16 +476,17 @@
Tensor block_shape_data(DT_INT32, TensorShape({2}));
test::FillValues<int32>(&block_shape_data, {1, 2});
- Output block_shape_op =
- Const(root.WithOpName("block_shape_op"), Input::Initializer(block_shape_data));
+ Output block_shape_op = Const(root.WithOpName("block_shape_op"),
+ Input::Initializer(block_shape_data));
Tensor crops_data(DT_INT32, TensorShape({2, 2}));
test::FillValues<int32>(&crops_data, {0, 0, 0, 1});
Output crops_op =
Const(root.WithOpName("crops_op"), Input::Initializer(crops_data));
- Output batch_to_space_op = BatchToSpaceND(root.WithOpName("batch_to_space_op"),
- conv_op, block_shape_op, crops_data);
+ Output batch_to_space_op =
+ BatchToSpaceND(root.WithOpName("batch_to_space_op"), conv_op,
+ block_shape_op, crops_data);
Tensor mean_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&mean_data, {10.0f, 20.0f});
@@ -339,8 +495,8 @@
Tensor variance_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&variance_data, {0.25f, 0.5f});
- Output variance_op = Const(root.WithOpName("variance_op"),
- Input::Initializer(variance_data));
+ Output variance_op =
+ Const(root.WithOpName("variance_op"), Input::Initializer(variance_data));
Tensor beta_data(DT_FLOAT, TensorShape({2}));
test::FillValues<float>(&beta_data, {0.1f, 0.6f});
@@ -410,5 +566,14 @@
TestFoldFusedBatchNormsWithBatchToSpace();
}
+TEST_F(FoldOldBatchNormsTest, TestFoldOldBatchNormsAfterDepthwiseConv2dNative) {
+ TestFoldOldBatchNormsAfterDepthwiseConv2dNative();
+}
+
+TEST_F(FoldOldBatchNormsTest,
+ TestFoldFusedBatchNormsAfterDepthwiseConv2dNative) {
+ TestFoldFusedBatchNormsAfterDepthwiseConv2dNative();
+}
+
} // namespace graph_transforms
} // namespace tensorflow
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 30a31a3..30ee516 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -77,6 +77,7 @@
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/data/benchmarks:benchmark_base",
"//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
"//tensorflow/python/data/kernel_tests:test_base",
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 1165a1a..e97041d 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -140,12 +140,12 @@
tf_http_archive(
name = "eigen_archive",
build_file = clean_dep("//third_party:eigen.BUILD"),
- patch_file = clean_dep("//third_party/eigen3:gebp_neon.patch"),
- sha256 = "48678550a32665331d729be87076e576f2502fff325f5b6c2c78ebf7b1b22c7b",
- strip_prefix = "eigen-eigen-bcc817c0ba98",
+ patch_file = clean_dep("//third_party/eigen3:gpu_packet_math.patch"),
+ sha256 = "d1d2ac19b8ef386ad70b91932a90bfbc3014b801d14723b9c8373239128bd2dd",
+ strip_prefix = "eigen-eigen-1724f8760da8",
urls = [
- "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/bcc817c0ba98.tar.gz",
- "https://bitbucket.org/eigen/eigen/get/bcc817c0ba98.tar.gz",
+ "https://mirror.bazel.build/bitbucket.org/eigen/eigen/get/1724f8760da8.tar.gz",
+ "https://bitbucket.org/eigen/eigen/get/1724f8760da8.tar.gz",
],
)
@@ -514,11 +514,11 @@
tf_http_archive(
name = "llvm",
build_file = clean_dep("//third_party/llvm:llvm.autogenerated.BUILD"),
- sha256 = "48699c52e64e428388b0c7d2daa5197864f1f2425608a89fc15d494bf4bab0e0",
- strip_prefix = "llvm-66b8f683ab6b3be52fe606a0db0f377f9f66e170",
+ sha256 = "ec524f6c8e4d7514b1753131e844c73df4cef1a22d8773b0e43ccd9f813bbd48",
+ strip_prefix = "llvm-ba3936b0ea89935ab8ea943c1b372d69923691d9",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/66b8f683ab6b3be52fe606a0db0f377f9f66e170.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/66b8f683ab6b3be52fe606a0db0f377f9f66e170.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/ba3936b0ea89935ab8ea943c1b372d69923691d9.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/ba3936b0ea89935ab8ea943c1b372d69923691d9.tar.gz",
],
)
@@ -719,16 +719,6 @@
)
tf_http_archive(
- name = "bazel_toolchains",
- sha256 = "07dfbe80638eb1fe681f7c07e61b34b579c6710c691e49ee90ccdc6e9e75ebbb",
- strip_prefix = "bazel-toolchains-9a111bd82161c1fbe8ed17a593ca1023fd941c70",
- urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/9a111bd82161c1fbe8ed17a593ca1023fd941c70.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/9a111bd82161c1fbe8ed17a593ca1023fd941c70.tar.gz",
- ],
- )
-
- tf_http_archive(
name = "arm_neon_2_x86_sse",
build_file = clean_dep("//third_party:arm_neon_2_x86_sse.BUILD"),
sha256 = "213733991310b904b11b053ac224fee2d4e0179e46b52fe7f8735b8831e04dcc",
diff --git a/third_party/eigen3/gebp_neon.patch b/third_party/eigen3/gebp_neon.patch
deleted file mode 100644
index d0022e9..0000000
--- a/third_party/eigen3/gebp_neon.patch
+++ /dev/null
@@ -1,11 +0,0 @@
---- a/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-01-22 20:46:51.000000000 -0800
-+++ b/Eigen/src/Core/products/GeneralBlockPanelKernel.h 2019-01-25 13:48:49.000000000 -0800
-@@ -1031,7 +1031,7 @@
-
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacket& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
- {
-- c += a * b;
-+ c = vfmaq_n_f32(c, a, b);
- }
-
- EIGEN_STRONG_INLINE void madd(const LhsPacket& a, const RhsPacketx4& b, AccPacket& c, RhsPacket& /*tmp*/, const FixedInt<0>&) const
diff --git a/third_party/eigen3/gpu_packet_math.patch b/third_party/eigen3/gpu_packet_math.patch
new file mode 100644
index 0000000..0634753
--- /dev/null
+++ b/third_party/eigen3/gpu_packet_math.patch
@@ -0,0 +1,18 @@
+--- a/Eigen/src/Core/arch/GPU/PacketMath.h
++++ b/Eigen/src/Core/arch/GPU/PacketMath.h
+@@ -100,6 +100,7 @@
+ return make_double2(from, from);
+ }
+
++#if defined(EIGEN_CUDA_ARCH)
+ namespace {
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float bitwise_and(const float& a,
+@@ -211,6 +212,7 @@
+ pcmp_eq<double2>(const double2& a, const double2& b) {
+ return make_double2(eq_mask(a.x, b.x), eq_mask(a.y, b.y));
+ }
++#endif // EIGEN_CUDA_ARCH
+
+ template<> EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float4 plset<float4>(const float& a) {
+ return make_float4(a, a+1, a+2, a+3);
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
index 223ea4d..8df6782 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/src/FixedPoint/PacketMathAVX2.h
@@ -178,25 +178,37 @@
struct unpacket_traits<Packet32q8i> {
typedef QInt8 type;
typedef Packet16q8i half;
- enum { size = 32, alignment = Aligned32 };
+ enum { size = 32, alignment = Aligned32, vectorizable = true };
+};
+template <>
+struct unpacket_traits<Packet16q8i> {
+ typedef QInt8 type;
+ typedef Packet16q8i half;
+ enum { size = 16, alignment = Aligned32, vectorizable = true };
};
template <>
struct unpacket_traits<Packet16q16i> {
typedef QInt16 type;
typedef Packet8q16i half;
- enum { size = 16, alignment = Aligned32 };
+ enum { size = 16, alignment = Aligned32, vectorizable = true };
+};
+template <>
+struct unpacket_traits<Packet8q16i> {
+ typedef QInt16 type;
+ typedef Packet8q16i half;
+ enum { size = 8, alignment = Aligned32, vectorizable = true };
};
template <>
struct unpacket_traits<Packet32q8u> {
typedef QUInt8 type;
typedef Packet16q8u half;
- enum { size = 32, alignment = Aligned32 };
+ enum { size = 32, alignment = Aligned32, vectorizable = true };
};
template <>
struct unpacket_traits<Packet8q32i> {
typedef QInt32 type;
typedef Packet4q32i half;
- enum { size = 8, alignment = Aligned32 };
+ enum { size = 8, alignment = Aligned32, vectorizable = true };
};
// Unaligned load
@@ -206,6 +218,11 @@
reinterpret_cast<const __m256i*>(from));
}
template <>
+EIGEN_STRONG_INLINE Packet16q8i ploadu<Packet16q8i>(const QInt8* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(from));
+}
+template <>
EIGEN_STRONG_INLINE Packet32q8u ploadu<Packet32q8u>(const QUInt8* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(from));
@@ -215,6 +232,11 @@
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
reinterpret_cast<const __m256i*>(from));
}
+template<>
+EIGEN_STRONG_INLINE Packet8q16i ploadu<Packet8q16i>(const QInt16* from) {
+ EIGEN_DEBUG_UNALIGNED_LOAD return _mm_loadu_si128(
+ reinterpret_cast<const __m128i*>(from));
+}
template <>
EIGEN_STRONG_INLINE Packet8q32i ploadu<Packet8q32i>(const QInt32* from) {
EIGEN_DEBUG_UNALIGNED_LOAD return _mm256_loadu_si256(
@@ -228,6 +250,11 @@
reinterpret_cast<const __m256i*>(from));
}
template <>
+EIGEN_STRONG_INLINE Packet16q8i pload<Packet16q8i>(const QInt8* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(
+ reinterpret_cast<const __m128i*>(from));
+}
+template <>
EIGEN_STRONG_INLINE Packet32q8u pload<Packet32q8u>(const QUInt8* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
reinterpret_cast<const __m256i*>(from));
@@ -238,6 +265,11 @@
reinterpret_cast<const __m256i*>(from));
}
template <>
+EIGEN_STRONG_INLINE Packet8q16i pload<Packet8q16i>(const QInt16* from) {
+ EIGEN_DEBUG_ALIGNED_LOAD return _mm_load_si128(
+ reinterpret_cast<const __m128i*>(from));
+}
+template <>
EIGEN_STRONG_INLINE Packet8q32i pload<Packet8q32i>(const QInt32* from) {
EIGEN_DEBUG_ALIGNED_LOAD return _mm256_load_si256(
reinterpret_cast<const __m256i*>(from));
@@ -250,6 +282,11 @@
reinterpret_cast<__m256i*>(to), from.val);
}
template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt8>(QInt8* to, const Packet16q8i& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(
+ reinterpret_cast<__m128i*>(to), from.val);
+}
+template <>
EIGEN_STRONG_INLINE void pstoreu<QUInt8>(QUInt8* to, const Packet32q8u& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
reinterpret_cast<__m256i*>(to), from.val);
@@ -260,6 +297,11 @@
reinterpret_cast<__m256i*>(to), from.val);
}
template <>
+EIGEN_STRONG_INLINE void pstoreu<QInt16>(QInt16* to, const Packet8q16i& from) {
+ EIGEN_DEBUG_UNALIGNED_STORE _mm_storeu_si128(
+ reinterpret_cast<__m128i*>(to), from.val);
+}
+template <>
EIGEN_STRONG_INLINE void pstoreu<QInt32>(QInt32* to, const Packet8q32i& from) {
EIGEN_DEBUG_UNALIGNED_STORE _mm256_storeu_si256(
reinterpret_cast<__m256i*>(to), from.val);
@@ -277,6 +319,11 @@
from.val);
}
template <>
+EIGEN_STRONG_INLINE void pstore<QInt16>(QInt16* to, const Packet8q16i& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to),
+ from.val);
+}
+template <>
EIGEN_STRONG_INLINE void pstore<QUInt8>(QUInt8* to, const Packet32q8u& from) {
EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
from.val);
@@ -286,6 +333,11 @@
EIGEN_DEBUG_ALIGNED_STORE _mm256_store_si256(reinterpret_cast<__m256i*>(to),
from.val);
}
+template <>
+EIGEN_STRONG_INLINE void pstore<QInt8>(QInt8* to, const Packet16q8i& from) {
+ EIGEN_DEBUG_ALIGNED_STORE _mm_store_si128(reinterpret_cast<__m128i*>(to),
+ from.val);
+}
// Extract first element.
template <>
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
index 8242380..57c7b63 100755
--- a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_rocm.tpl
@@ -205,25 +205,25 @@
args, leftover = parser.parse_known_args(sys.argv[1:])
if args.x and args.x[0] == 'rocm':
+ # XXX use hipcc to link
+ if args.pass_exit_codes:
+ gpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('-pass-exit-codes'))]
+
+ # special handling for $ORIGIN
+ # - guard every argument with ''
+ modified_gpu_compiler_flags = []
+ for flag in gpu_compiler_flags:
+ modified_gpu_compiler_flags.append("'" + flag + "'")
+
+ if args.rocm_log: Log('Link with hipcc: %s' % (' '.join([HIPCC_PATH] + modified_gpu_compiler_flags)))
+ return subprocess.call([HIPCC_PATH] + modified_gpu_compiler_flags)
+
if args.rocm_log: Log('-x rocm')
leftover = [pipes.quote(s) for s in leftover]
if args.rocm_log: Log('using hipcc')
return InvokeHipcc(leftover, log=args.rocm_log)
- # XXX use hipcc to link
- if args.pass_exit_codes:
- gpu_compiler_flags = [flag for flag in sys.argv[1:]
- if not flag.startswith(('-pass-exit-codes'))]
-
- # special handling for $ORIGIN
- # - guard every argument with ''
- modified_gpu_compiler_flags = []
- for flag in gpu_compiler_flags:
- modified_gpu_compiler_flags.append("'" + flag + "'")
-
- if args.rocm_log: Log('Link with hipcc: %s' % (' '.join([HIPCC_PATH] + modified_gpu_compiler_flags)))
- return subprocess.call([HIPCC_PATH] + modified_gpu_compiler_flags)
-
# Strip our flags before passing through to the CPU compiler for files which
# are not -x rocm. We can't just pass 'leftover' because it also strips -x.
# We not only want to pass -x to the CPU compiler, but also keep it in its
diff --git a/third_party/hwloc/BUILD b/third_party/hwloc/BUILD
index 2f5d02b..2469c95 100644
--- a/third_party/hwloc/BUILD
+++ b/third_party/hwloc/BUILD
@@ -1 +1,7 @@
-# Dummy BUILD file to make this directory a package.
+# BUILD file to make this directory a package.
+
+licenses(["notice"])
+
+exports_files(
+ ["static-components.h"],
+)
diff --git a/third_party/hwloc/BUILD.bazel b/third_party/hwloc/BUILD.bazel
index 1f29301..5d636ca 100644
--- a/third_party/hwloc/BUILD.bazel
+++ b/third_party/hwloc/BUILD.bazel
@@ -31,11 +31,11 @@
src = "include/hwloc/autogen/config.h.in",
out = "include/hwloc/autogen/config.h",
substitutions = {
- "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
"#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
"#undef HWLOC_VERSION_MINOR": "#define HWLOC_VERSION_MINOR 0",
"#undef HWLOC_VERSION_RELEASE": "#define HWLOC_VERSION_RELEASE 3",
"#undef HWLOC_VERSION_GREEK": "#define HWLOC_VERSION_GREEK \"\"",
+ "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
"#undef HWLOC_LINUX_SYS": "#define HWLOC_LINUX_SYS 1",
"#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
"#undef hwloc_thread_t": "#define hwloc_thread_t pthread_t",
@@ -46,154 +46,168 @@
},
)
+_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS = {
+ "#undef HAVE_CLOCK_GETTIME": "#define HAVE_CLOCK_GETTIME 1",
+ "#undef HAVE_CTYPE_H": "#define HAVE_CTYPE_H 1",
+ "#undef HAVE_DECL_CTL_HW": "#define HAVE_DECL_CTL_HW 0",
+ "#undef HAVE_DECL_FABSF": "#define HAVE_DECL_FABSF 1",
+ "#undef HAVE_DECL_GETEXECNAME": "#define HAVE_DECL_GETEXECNAME 0",
+ "#undef HAVE_DECL_GETMODULEFILENAME": "#define HAVE_DECL_GETMODULEFILENAME 0",
+ "#undef HAVE_DECL_GETPROGNAME": "#define HAVE_DECL_GETPROGNAME 0",
+ "#undef HAVE_DECL_HW_NCPU": "#define HAVE_DECL_HW_NCPU 0",
+ "#undef HAVE_DECL_MODFF": "#define HAVE_DECL_MODFF 1",
+ "#undef HAVE_DECL_PTHREAD_GETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_GETAFFINITY_NP 1",
+ "#undef HAVE_DECL_PTHREAD_SETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_SETAFFINITY_NP 1",
+ "#undef HAVE_DECL_RUNNING_ON_VALGRIND": "#define HAVE_DECL_RUNNING_ON_VALGRIND 0",
+ "#undef HAVE_DECL_SCHED_GETCPU": "#define HAVE_DECL_SCHED_GETCPU 1",
+ "#undef HAVE_DECL_SNPRINTF": "#define HAVE_DECL_SNPRINTF 1",
+ "#undef HAVE_DECL_STRTOULL": "#define HAVE_DECL_STRTOULL 1",
+ "#undef HAVE_DECL__PUTENV": "#define HAVE_DECL__PUTENV 0",
+ "#undef HAVE_DECL__SC_LARGE_PAGESIZE": "#define HAVE_DECL__SC_LARGE_PAGESIZE 0",
+ "#undef HAVE_DECL__SC_NPROCESSORS_CONF": "#define HAVE_DECL__SC_NPROCESSORS_CONF 1",
+ "#undef HAVE_DECL__SC_NPROCESSORS_ONLN": "#define HAVE_DECL__SC_NPROCESSORS_ONLN 1",
+ "#undef HAVE_DECL__SC_NPROC_CONF": "#define HAVE_DECL__SC_NPROC_CONF 0",
+ "#undef HAVE_DECL__SC_NPROC_ONLN": "#define HAVE_DECL__SC_NPROC_ONLN 0",
+ "#undef HAVE_DECL__SC_PAGESIZE": "#define HAVE_DECL__SC_PAGESIZE 1",
+ "#undef HAVE_DECL__SC_PAGE_SIZE": "#define HAVE_DECL__SC_PAGE_SIZE 1",
+ "#undef HAVE_DECL__STRDUP": "#define HAVE_DECL__STRDUP 0",
+ "#undef HAVE_DIRENT_H": "#define HAVE_DIRENT_H 1",
+ "#undef HAVE_DLFCN_H": "#define HAVE_DLFCN_H 1",
+ "#undef HAVE_FFS": "#define HAVE_FFS 1",
+ "#undef HAVE_FFSL": "#define HAVE_FFSL 1",
+ "#undef HAVE_GETPAGESIZE": "#define HAVE_GETPAGESIZE 1",
+ "#undef HAVE_INTTYPES_H": "#define HAVE_INTTYPES_H 1",
+ "#undef HAVE_LANGINFO_H": "#define HAVE_LANGINFO_H 1",
+ "#undef HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1",
+ "#undef HAVE_MALLOC_H": "#define HAVE_MALLOC_H 1",
+ "#undef HAVE_MEMALIGN": "#define HAVE_MEMALIGN 1",
+ "#undef HAVE_MEMORY_H": "#define HAVE_MEMORY_H 1",
+ "#undef HAVE_MKSTEMP": "#define HAVE_MKSTEMP 1",
+ "#undef HAVE_NL_LANGINFO": "#define HAVE_NL_LANGINFO 1",
+ "#undef HAVE_OPENAT": "#define HAVE_OPENAT 1",
+ "#undef HAVE_POSIX_MEMALIGN": "#define HAVE_POSIX_MEMALIGN 1",
+ "#undef HAVE_PROGRAM_INVOCATION_NAME": "#define HAVE_PROGRAM_INVOCATION_NAME 1",
+ "#undef HAVE_PTHREAD_T": "#define HAVE_PTHREAD_T 1",
+ "#undef HAVE_PUTWC": "#define HAVE_PUTWC 1",
+ "#undef HAVE_SETLOCALE": "#define HAVE_SETLOCALE 1",
+ "#undef HAVE_SSIZE_T": "#define HAVE_SSIZE_T 1",
+ "#undef HAVE_STDINT_H": "#define HAVE_STDINT_H 1",
+ "#undef HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1",
+ "#undef HAVE_STRCASECMP": "#define HAVE_STRCASECMP 1",
+ "#undef HAVE_STRFTIME": "#define HAVE_STRFTIME 1",
+ "#undef HAVE_STRINGS_H": "#define HAVE_STRINGS_H 1",
+ "#undef HAVE_STRING_H": "#define HAVE_STRING_H 1",
+ "#undef HAVE_STRNCASECMP": "#define HAVE_STRNCASECMP 1",
+ "#undef HAVE_SYS_MMAN_H": "#define HAVE_SYS_MMAN_H 1",
+ "#undef HAVE_SYS_PARAM_H": "#define HAVE_SYS_PARAM_H 1",
+ "#undef HAVE_SYS_STAT_H": "#define HAVE_SYS_STAT_H 1",
+ "#undef HAVE_SYS_SYSCTL_H": "#define HAVE_SYS_SYSCTL_H 1",
+ "#undef HAVE_SYS_TYPES_H": "#define HAVE_SYS_TYPES_H 1",
+ "#undef HAVE_SYS_UTSNAME_H": "#define HAVE_SYS_UTSNAME_H 1",
+ "#undef HAVE_TIME_H": "#define HAVE_TIME_H 1",
+ "#undef HAVE_UNAME": "#define HAVE_UNAME 1",
+ "#undef HAVE_UNISTD_H": "#define HAVE_UNISTD_H 1",
+ "#undef HAVE_USELOCALE": "#define HAVE_USELOCALE 1",
+ "#undef HAVE_WCHAR_T": "#define HAVE_WCHAR_T 1",
+ "#undef HAVE_X11_KEYSYM_H": "#define HAVE_X11_KEYSYM_H 1",
+ "#undef HAVE_X11_XLIB_H": "#define HAVE_X11_XLIB_H 1",
+ "#undef HAVE_X11_XUTIL_H": "#define HAVE_X11_XUTIL_H 1",
+ "#undef HAVE_XLOCALE_H": "#define HAVE_XLOCALE_H 1",
+ "#undef HAVE___PROGNAME": "#define HAVE___PROGNAME 1",
+ "#undef HWLOC_C_HAVE_VISIBILITY": "#define HWLOC_C_HAVE_VISIBILITY 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE": "#define HWLOC_HAVE_ATTRIBUTE 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_ALIGNED": "#define HWLOC_HAVE_ATTRIBUTE_ALIGNED 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE": "#define HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_COLD": "#define HWLOC_HAVE_ATTRIBUTE_COLD 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_CONST": "#define HWLOC_HAVE_ATTRIBUTE_CONST 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_DEPRECATED": "#define HWLOC_HAVE_ATTRIBUTE_DEPRECATED 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_FORMAT": "#define HWLOC_HAVE_ATTRIBUTE_FORMAT 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_HOT": "#define HWLOC_HAVE_ATTRIBUTE_HOT 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_MALLOC": "#define HWLOC_HAVE_ATTRIBUTE_MALLOC 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_NONNULL": "#define HWLOC_HAVE_ATTRIBUTE_NONNULL 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_NORETURN": "#define HWLOC_HAVE_ATTRIBUTE_NORETURN 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION": "#define HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_PACKED": "#define HWLOC_HAVE_ATTRIBUTE_PACKED 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_PURE": "#define HWLOC_HAVE_ATTRIBUTE_PURE 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_SENTINEL": "#define HWLOC_HAVE_ATTRIBUTE_SENTINEL 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_UNUSED": "#define HWLOC_HAVE_ATTRIBUTE_UNUSED 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT": "#define HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT 1",
+ "#undef HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS 1",
+ "#undef HWLOC_HAVE_CPU_SET": "#define HWLOC_HAVE_CPU_SET 1",
+ "#undef HWLOC_HAVE_CPU_SET_S": "#define HWLOC_HAVE_CPU_SET_S 1",
+ "#undef HWLOC_HAVE_DECL_FFS": "#define HWLOC_HAVE_DECL_FFS 1",
+ "#undef HWLOC_HAVE_DECL_FFSL": "#define HWLOC_HAVE_DECL_FFSL 1",
+ "#undef HWLOC_HAVE_DECL_STRCASECMP": "#define HWLOC_HAVE_DECL_STRCASECMP 1",
+ "#undef HWLOC_HAVE_DECL_STRNCASECMP": "#define HWLOC_HAVE_DECL_STRNCASECMP 1",
+ "#undef HWLOC_HAVE_FFS": "#define HWLOC_HAVE_FFS 1",
+ "#undef HWLOC_HAVE_FFSL": "#define HWLOC_HAVE_FFSL 1",
+ "#undef HWLOC_HAVE_LIBTERMCAP": "#define HWLOC_HAVE_LIBTERMCAP 1",
+ "#undef HWLOC_HAVE_LINUXIO": "#define HWLOC_HAVE_LINUXIO 1",
+ "#undef HWLOC_HAVE_PTHREAD_MUTEX": "#define HWLOC_HAVE_PTHREAD_MUTEX 1",
+ "#undef HWLOC_HAVE_SCHED_SETAFFINITY": "#define HWLOC_HAVE_SCHED_SETAFFINITY 1",
+ "#undef HWLOC_HAVE_STDINT_H": "#define HWLOC_HAVE_STDINT_H 1",
+ "#undef HWLOC_HAVE_SYSCALL": "#define HWLOC_HAVE_SYSCALL 1",
+ "#undef HWLOC_HAVE_X11_KEYSYM": "#define HWLOC_HAVE_X11_KEYSYM 1",
+ "#undef HWLOC_HAVE_X86_CPUID": "#define HWLOC_HAVE_X86_CPUID 1",
+ "#undef HWLOC_LINUX_SYS": "#define HWLOC_LINUX_SYS 1",
+ "#undef HWLOC_SIZEOF_UNSIGNED_INT": "#define HWLOC_SIZEOF_UNSIGNED_INT 4",
+ "#undef HWLOC_SIZEOF_UNSIGNED_LONG": "#define HWLOC_SIZEOF_UNSIGNED_LONG 8",
+ "#undef HWLOC_SYM_PREFIX": "#define HWLOC_SYM_PREFIX hwloc_",
+ "#undef HWLOC_SYM_PREFIX_CAPS": "#define HWLOC_SYM_PREFIX_CAPS HWLOC_",
+ "#undef HWLOC_SYM_TRANSFORM": "#define HWLOC_SYM_TRANSFORM 0",
+ "#undef HWLOC_USE_NCURSES": "#define HWLOC_USE_NCURSES 1",
+ "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
+ "#undef HWLOC_VERSION_GREEK": "#define HWLOC_VERSION_GREEK \"\"",
+ "#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
+ "#undef HWLOC_VERSION_MINOR": "#define HWLOC_VERSION_MINOR 0",
+ "#undef HWLOC_VERSION_RELEASE": "#define HWLOC_VERSION_RELEASE 3",
+ "#undef HWLOC_X86_64_ARCH": "#define HWLOC_X86_64_ARCH 1",
+ "#undef LT_OBJDIR": "#define LT_OBJDIR \".libs/\"",
+ "#undef PACKAGE": "#define PACKAGE \"hwloc\"",
+ "#undef PACKAGE_BUGREPORT": "#define PACKAGE_BUGREPORT \"http://github.com/open-mpi/hwloc/i",
+ "#undef PACKAGE_NAME": "#define PACKAGE_NAME \"hwloc\"",
+ "#undef PACKAGE_STRING": "#define PACKAGE_STRING \"hwloc 2.0.3\"",
+ "#undef PACKAGE_TARNAME": "#define PACKAGE_TARNAME \"hwloc\"",
+ "#undef PACKAGE_URL": "#define PACKAGE_URL \"\"",
+ "#undef PACKAGE_VERSION": "#define PACKAGE_VERSION \"2.0.3\"",
+ "#undef SIZEOF_UNSIGNED_INT": "#define SIZEOF_UNSIGNED_INT 4",
+ "#undef SIZEOF_UNSIGNED_LONG": "#define SIZEOF_UNSIGNED_LONG 8",
+ "#undef SIZEOF_VOID_P": "#define SIZEOF_VOID_P 8",
+ "#undef STDC_HEADERS": "#define STDC_HEADERS 1",
+ "# undef _HPUX_SOURCE": "# define _HPUX_SOURCE 1",
+ "# undef _ALL_SOURCE": "# define _ALL_SOURCE 1",
+ "# undef _GNU_SOURCE": "# define _GNU_SOURCE 1",
+ "# undef _POSIX_PTHREAD_SEMANTICS": "# define _POSIX_PTHREAD_SEMANTICS 1",
+ "# undef _TANDEM_SOURCE": "# define _TANDEM_SOURCE 1",
+ "# undef __EXTENSIONS__": "# define __EXTENSIONS__ 1",
+ "#undef VERSION": "#define VERSION \"2.0.3\"",
+ "#undef _HPUX_SOURCE": "#define _HPUX_SOURCE 1",
+ "#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
+ "#undef hwloc_thread_t": "#define hwloc_thread_t pthread_t",
+}
+
+_INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_CUDA_SUBS = {
+ "#undef HAVE_CUDA": "#undef HAVE_CUDA 1",
+ "#undef HAVE_CUDA_H": "#undef HAVE_CUDA_H 1",
+ "#undef HAVE_CUDA_RUNTIME_API_H": "#undef HAVE_CUDA_RUNTIME_API_H 1",
+} + _INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS
+
template_rule(
name = "include_private_hwloc_autogen__config_h",
src = "include/private/autogen/config.h.in",
out = "include/private/autogen/config.h",
- substitutions = {
- "#undef HAVE_CLOCK_GETTIME": "#define HAVE_CLOCK_GETTIME 1",
- "#undef HAVE_CTYPE_H": "#define HAVE_CTYPE_H 1",
- "#undef HAVE_DECL_CTL_HW": "#define HAVE_DECL_CTL_HW 0",
- "#undef HAVE_DECL_FABSF": "#define HAVE_DECL_FABSF 1",
- "#undef HAVE_DECL_GETEXECNAME": "#define HAVE_DECL_GETEXECNAME 0",
- "#undef HAVE_DECL_GETMODULEFILENAME": "#define HAVE_DECL_GETMODULEFILENAME 0",
- "#undef HAVE_DECL_GETPROGNAME": "#define HAVE_DECL_GETPROGNAME 0",
- "#undef HAVE_DECL_HW_NCPU": "#define HAVE_DECL_HW_NCPU 0",
- "#undef HAVE_DECL_MODFF": "#define HAVE_DECL_MODFF 1",
- "#undef HAVE_DECL_PTHREAD_GETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_GETAFFINITY_NP 1",
- "#undef HAVE_DECL_PTHREAD_SETAFFINITY_NP": "#define HAVE_DECL_PTHREAD_SETAFFINITY_NP 1",
- "#undef HAVE_DECL_RUNNING_ON_VALGRIND": "#define HAVE_DECL_RUNNING_ON_VALGRIND 0",
- "#undef HAVE_DECL_SCHED_GETCPU": "#define HAVE_DECL_SCHED_GETCPU 1",
- "#undef HAVE_DECL_SNPRINTF": "#define HAVE_DECL_SNPRINTF 1",
- "#undef HAVE_DECL_STRTOULL": "#define HAVE_DECL_STRTOULL 1",
- "#undef HAVE_DECL__PUTENV": "#define HAVE_DECL__PUTENV 0",
- "#undef HAVE_DECL__SC_LARGE_PAGESIZE": "#define HAVE_DECL__SC_LARGE_PAGESIZE 0",
- "#undef HAVE_DECL__SC_NPROCESSORS_CONF": "#define HAVE_DECL__SC_NPROCESSORS_CONF 1",
- "#undef HAVE_DECL__SC_NPROCESSORS_ONLN": "#define HAVE_DECL__SC_NPROCESSORS_ONLN 1",
- "#undef HAVE_DECL__SC_NPROC_CONF": "#define HAVE_DECL__SC_NPROC_CONF 0",
- "#undef HAVE_DECL__SC_NPROC_ONLN": "#define HAVE_DECL__SC_NPROC_ONLN 0",
- "#undef HAVE_DECL__SC_PAGESIZE": "#define HAVE_DECL__SC_PAGESIZE 1",
- "#undef HAVE_DECL__SC_PAGE_SIZE": "#define HAVE_DECL__SC_PAGE_SIZE 1",
- "#undef HAVE_DECL__STRDUP": "#define HAVE_DECL__STRDUP 0",
- "#undef HAVE_DIRENT_H": "#define HAVE_DIRENT_H 1",
- "#undef HAVE_DLFCN_H": "#define HAVE_DLFCN_H 1",
- "#undef HAVE_FFS": "#define HAVE_FFS 1",
- "#undef HAVE_FFSL": "#define HAVE_FFSL 1",
- "#undef HAVE_GETPAGESIZE": "#define HAVE_GETPAGESIZE 1",
- "#undef HAVE_INTTYPES_H": "#define HAVE_INTTYPES_H 1",
- "#undef HAVE_LANGINFO_H": "#define HAVE_LANGINFO_H 1",
- "#undef HAVE_LOCALE_H": "#define HAVE_LOCALE_H 1",
- "#undef HAVE_MALLOC_H": "#define HAVE_MALLOC_H 1",
- "#undef HAVE_MEMALIGN": "#define HAVE_MEMALIGN 1",
- "#undef HAVE_MEMORY_H": "#define HAVE_MEMORY_H 1",
- "#undef HAVE_MKSTEMP": "#define HAVE_MKSTEMP 1",
- "#undef HAVE_NL_LANGINFO": "#define HAVE_NL_LANGINFO 1",
- "#undef HAVE_OPENAT": "#define HAVE_OPENAT 1",
- "#undef HAVE_POSIX_MEMALIGN": "#define HAVE_POSIX_MEMALIGN 1",
- "#undef HAVE_PROGRAM_INVOCATION_NAME": "#define HAVE_PROGRAM_INVOCATION_NAME 1",
- "#undef HAVE_PTHREAD_T": "#define HAVE_PTHREAD_T 1",
- "#undef HAVE_PUTWC": "#define HAVE_PUTWC 1",
- "#undef HAVE_SETLOCALE": "#define HAVE_SETLOCALE 1",
- "#undef HAVE_SSIZE_T": "#define HAVE_SSIZE_T 1",
- "#undef HAVE_STDINT_H": "#define HAVE_STDINT_H 1",
- "#undef HAVE_STDLIB_H": "#define HAVE_STDLIB_H 1",
- "#undef HAVE_STRCASECMP": "#define HAVE_STRCASECMP 1",
- "#undef HAVE_STRFTIME": "#define HAVE_STRFTIME 1",
- "#undef HAVE_STRINGS_H": "#define HAVE_STRINGS_H 1",
- "#undef HAVE_STRING_H": "#define HAVE_STRING_H 1",
- "#undef HAVE_STRNCASECMP": "#define HAVE_STRNCASECMP 1",
- "#undef HAVE_SYS_MMAN_H": "#define HAVE_SYS_MMAN_H 1",
- "#undef HAVE_SYS_PARAM_H": "#define HAVE_SYS_PARAM_H 1",
- "#undef HAVE_SYS_STAT_H": "#define HAVE_SYS_STAT_H 1",
- "#undef HAVE_SYS_SYSCTL_H": "#define HAVE_SYS_SYSCTL_H 1",
- "#undef HAVE_SYS_TYPES_H": "#define HAVE_SYS_TYPES_H 1",
- "#undef HAVE_SYS_UTSNAME_H": "#define HAVE_SYS_UTSNAME_H 1",
- "#undef HAVE_TIME_H": "#define HAVE_TIME_H 1",
- "#undef HAVE_UNAME": "#define HAVE_UNAME 1",
- "#undef HAVE_UNISTD_H": "#define HAVE_UNISTD_H 1",
- "#undef HAVE_USELOCALE": "#define HAVE_USELOCALE 1",
- "#undef HAVE_WCHAR_T": "#define HAVE_WCHAR_T 1",
- "#undef HAVE_X11_KEYSYM_H": "#define HAVE_X11_KEYSYM_H 1",
- "#undef HAVE_X11_XLIB_H": "#define HAVE_X11_XLIB_H 1",
- "#undef HAVE_X11_XUTIL_H": "#define HAVE_X11_XUTIL_H 1",
- "#undef HAVE_XLOCALE_H": "#define HAVE_XLOCALE_H 1",
- "#undef HAVE___PROGNAME": "#define HAVE___PROGNAME 1",
- "#undef HWLOC_C_HAVE_VISIBILITY": "#define HWLOC_C_HAVE_VISIBILITY 1",
- "#undef HWLOC_HAVE_ATTRIBUTE": "#define HWLOC_HAVE_ATTRIBUTE 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_ALIGNED": "#define HWLOC_HAVE_ATTRIBUTE_ALIGNED 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE": "#define HWLOC_HAVE_ATTRIBUTE_ALWAYS_INLINE 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_COLD": "#define HWLOC_HAVE_ATTRIBUTE_COLD 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_CONST": "#define HWLOC_HAVE_ATTRIBUTE_CONST 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_DEPRECATED": "#define HWLOC_HAVE_ATTRIBUTE_DEPRECATED 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_FORMAT": "#define HWLOC_HAVE_ATTRIBUTE_FORMAT 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_HOT": "#define HWLOC_HAVE_ATTRIBUTE_HOT 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_MALLOC": "#define HWLOC_HAVE_ATTRIBUTE_MALLOC 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_MAY_ALIAS 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_NONNULL": "#define HWLOC_HAVE_ATTRIBUTE_NONNULL 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_NORETURN": "#define HWLOC_HAVE_ATTRIBUTE_NORETURN 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION": "#define HWLOC_HAVE_ATTRIBUTE_NO_INSTRUMENT_FUNCTION 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_PACKED": "#define HWLOC_HAVE_ATTRIBUTE_PACKED 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_PURE": "#define HWLOC_HAVE_ATTRIBUTE_PURE 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_SENTINEL": "#define HWLOC_HAVE_ATTRIBUTE_SENTINEL 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_UNUSED": "#define HWLOC_HAVE_ATTRIBUTE_UNUSED 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT": "#define HWLOC_HAVE_ATTRIBUTE_WARN_UNUSED_RESULT 1",
- "#undef HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS": "#define HWLOC_HAVE_ATTRIBUTE_WEAK_ALIAS 1",
- "#undef HWLOC_HAVE_CPU_SET": "#define HWLOC_HAVE_CPU_SET 1",
- "#undef HWLOC_HAVE_CPU_SET_S": "#define HWLOC_HAVE_CPU_SET_S 1",
- "#undef HWLOC_HAVE_DECL_FFS": "#define HWLOC_HAVE_DECL_FFS 1",
- "#undef HWLOC_HAVE_DECL_FFSL": "#define HWLOC_HAVE_DECL_FFSL 1",
- "#undef HWLOC_HAVE_DECL_STRCASECMP": "#define HWLOC_HAVE_DECL_STRCASECMP 1",
- "#undef HWLOC_HAVE_DECL_STRNCASECMP": "#define HWLOC_HAVE_DECL_STRNCASECMP 1",
- "#undef HWLOC_HAVE_FFS": "#define HWLOC_HAVE_FFS 1",
- "#undef HWLOC_HAVE_FFSL": "#define HWLOC_HAVE_FFSL 1",
- "#undef HWLOC_HAVE_LIBTERMCAP": "#define HWLOC_HAVE_LIBTERMCAP 1",
- "#undef HWLOC_HAVE_LINUXIO": "#define HWLOC_HAVE_LINUXIO 1",
- "#undef HWLOC_HAVE_PTHREAD_MUTEX": "#define HWLOC_HAVE_PTHREAD_MUTEX 1",
- "#undef HWLOC_HAVE_SCHED_SETAFFINITY": "#define HWLOC_HAVE_SCHED_SETAFFINITY 1",
- "#undef HWLOC_HAVE_STDINT_H": "#define HWLOC_HAVE_STDINT_H 1",
- "#undef HWLOC_HAVE_SYSCALL": "#define HWLOC_HAVE_SYSCALL 1",
- "#undef HWLOC_HAVE_X11_KEYSYM": "#define HWLOC_HAVE_X11_KEYSYM 1",
- "#undef HWLOC_HAVE_X86_CPUID": "#define HWLOC_HAVE_X86_CPUID 1",
- "#undef HWLOC_LINUX_SYS": "#define HWLOC_LINUX_SYS 1",
- "#undef HWLOC_SIZEOF_UNSIGNED_INT": "#define HWLOC_SIZEOF_UNSIGNED_INT 4",
- "#undef HWLOC_SIZEOF_UNSIGNED_LONG": "#define HWLOC_SIZEOF_UNSIGNED_LONG 8",
- "#undef HWLOC_SYM_PREFIX": "#define HWLOC_SYM_PREFIX hwloc_",
- "#undef HWLOC_SYM_PREFIX_CAPS": "#define HWLOC_SYM_PREFIX_CAPS HWLOC_",
- "#undef HWLOC_SYM_TRANSFORM": "#define HWLOC_SYM_TRANSFORM 0",
- "#undef HWLOC_USE_NCURSES": "#define HWLOC_USE_NCURSES 1",
- "#undef HWLOC_VERSION": "#define HWLOC_VERSION \"2.0.3\"",
- "#undef HWLOC_VERSION_GREEK": "#define HWLOC_VERSION_GREEK \"\"",
- "#undef HWLOC_VERSION_MAJOR": "#define HWLOC_VERSION_MAJOR 2",
- "#undef HWLOC_VERSION_MINOR": "#define HWLOC_VERSION_MINOR 0",
- "#undef HWLOC_VERSION_RELEASE": "#define HWLOC_VERSION_RELEASE 3",
- "#undef HWLOC_X86_64_ARCH": "#define HWLOC_X86_64_ARCH 1",
- "#undef LT_OBJDIR": "#define LT_OBJDIR \".libs/\"",
- "#undef PACKAGE": "#define PACKAGE \"hwloc\"",
- "#undef PACKAGE_BUGREPORT": "#define PACKAGE_BUGREPORT \"http://github.com/open-mpi/hwloc/i",
- "#undef PACKAGE_NAME": "#define PACKAGE_NAME \"hwloc\"",
- "#undef PACKAGE_STRING": "#define PACKAGE_STRING \"hwloc 2.0.3\"",
- "#undef PACKAGE_TARNAME": "#define PACKAGE_TARNAME \"hwloc\"",
- "#undef PACKAGE_URL": "#define PACKAGE_URL \"\"",
- "#undef PACKAGE_VERSION": "#define PACKAGE_VERSION \"2.0.3\"",
- "#undef SIZEOF_UNSIGNED_INT": "#define SIZEOF_UNSIGNED_INT 4",
- "#undef SIZEOF_UNSIGNED_LONG": "#define SIZEOF_UNSIGNED_LONG 8",
- "#undef SIZEOF_VOID_P": "#define SIZEOF_VOID_P 8",
- "#undef STDC_HEADERS": "#define STDC_HEADERS 1",
- "# undef _HPUX_SOURCE": "# define _HPUX_SOURCE 1",
- "# undef _ALL_SOURCE": "# define _ALL_SOURCE 1",
- "# undef _GNU_SOURCE": "# define _GNU_SOURCE 1",
- "# undef _POSIX_PTHREAD_SEMANTICS": "# define _POSIX_PTHREAD_SEMANTICS 1",
- "# undef _TANDEM_SOURCE": "# define _TANDEM_SOURCE 1",
- "# undef __EXTENSIONS__": "# define __EXTENSIONS__ 1",
- "#undef VERSION": "#define VERSION \"2.0.3\"",
- "#undef _HPUX_SOURCE": "#define _HPUX_SOURCE 1",
- "#undef hwloc_pid_t": "#define hwloc_pid_t pid_t",
- "#undef hwloc_thread_t": "#define hwloc_thread_t pthread_t",
- } + if_cuda({
- "#undef HAVE_CUDA": "#undef HAVE_CUDA 1",
- "#undef HAVE_CUDA_H": "#undef HAVE_CUDA_H 1",
- "#undef HAVE_CUDA_RUNTIME_API_H": "#undef HAVE_CUDA_RUNTIME_API_H 1",
- }),
+ substitutions = if_cuda(
+ _INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_CUDA_SUBS,
+ if_false = _INCLUDE_PRIVATE_HWLOC_AUTOIGEN_CONFIG_H_COMMON_SUBS,
+ ),
+)
+
+template_rule(
+ name = "move_static_components_h",
+ src = "@org_tensorflow//third_party/hwloc:static-components.h",
+ out = "hwloc/static-components.h",
+ substitutions = {},
)
cc_library(
@@ -247,6 +261,10 @@
"-parse_headers",
"-layering_check",
],
+ includes = [
+ "hwloc",
+ "include",
+ ],
deps = [],
)
diff --git a/third_party/hwloc/static-components.h b/third_party/hwloc/static-components.h
new file mode 100644
index 0000000..8cae42a
--- /dev/null
+++ b/third_party/hwloc/static-components.h
@@ -0,0 +1,26 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
+#define THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
+
+#include <private/internal-components.h>
+static const struct hwloc_component* hwloc_static_components[] = {
+ &hwloc_noos_component, &hwloc_xml_component,
+ &hwloc_synthetic_component, &hwloc_xml_nolibxml_component,
+ &hwloc_linux_component, &hwloc_linuxio_component,
+ &hwloc_x86_component, NULL};
+
+#endif // THIRD_PARTY_HWLOC_STATIC_COMPONENTS_H_
diff --git a/third_party/toolchains/preconfig/generate/archives.bzl b/third_party/toolchains/preconfig/generate/archives.bzl
index bafc7d4..d20432e 100644
--- a/third_party/toolchains/preconfig/generate/archives.bzl
+++ b/third_party/toolchains/preconfig/generate/archives.bzl
@@ -1,13 +1,12 @@
-load("//tensorflow:version_check.bzl", "parse_bazel_version")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
def bazel_toolchains_archive():
http_archive(
name = "bazel_toolchains",
- sha256 = "ee854b5de299138c1f4a2edb5573d22b21d975acfc7aa938f36d30b49ef97498",
- strip_prefix = "bazel-toolchains-37419a124bdb9af2fec5b99a973d359b6b899b61",
+ sha256 = "109a99384f9d08f9e75136d218ebaebc68cc810c56897aea2224c57932052d30",
+ strip_prefix = "bazel-toolchains-94d31935a2c94fe7e7c7379a0f3393e181928ff7",
urls = [
- "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/37419a124bdb9af2fec5b99a973d359b6b899b61.tar.gz",
- "https://github.com/bazelbuild/bazel-toolchains/archive/37419a124bdb9af2fec5b99a973d359b6b899b61.tar.gz",
+ "https://mirror.bazel.build/github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz",
+ "https://github.com/bazelbuild/bazel-toolchains/archive/94d31935a2c94fe7e7c7379a0f3393e181928ff7.tar.gz",
],
)
diff --git a/third_party/toolchains/preconfig/generate/workspace.bzl b/third_party/toolchains/preconfig/generate/workspace.bzl
index 0495173..bce2d5b 100644
--- a/third_party/toolchains/preconfig/generate/workspace.bzl
+++ b/third_party/toolchains/preconfig/generate/workspace.bzl
@@ -1,7 +1,10 @@
load(
+ "@io_bazel_rules_docker//repositories:repositories.bzl",
+ container_repositories = "repositories",
+)
+load(
"@io_bazel_rules_docker//container:container.bzl",
"container_pull",
- container_repositories = "repositories",
)
load(":containers.bzl", "container_digests")